Improve interactive generation (support --tokenizer and --bpe)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/734

Differential Revision: D16377044

Pulled By: myleott

fbshipit-source-id: 37d5553d76aa7c653113fec089f59710281c31d7
Myle Ott 2019-07-19 06:34:46 -07:00 committed by Facebook Github Bot
parent be5821b82b
commit 8af5554269
9 changed files with 56 additions and 44 deletions

View File

@@ -19,23 +19,27 @@ flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
using ``tokenizer.perl`` from
`mosesdecoder <https://github.com/moses-smt/mosesdecoder>`__.
Let's use :ref:`fairseq-interactive` to generate translations
interactively. Here, we use a beam size of 5:
Let's use :ref:`fairseq-interactive` to generate translations interactively.
Here, we use a beam size of 5 and preprocess the input with the Moses
tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically
remove the BPE continuation markers and detokenize the output.
.. code-block:: console
> MODEL_DIR=wmt14.en-fr.fconv-py
> fairseq-interactive \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5 --source-lang en --target-lang fr
--beam 5 --source-lang en --target-lang fr \
--tokenizer moses \
--bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| Type the input sentence and press return:
> Why is it rare to discover new marine mam@@ mal species ?
O Why is it rare to discover new marine mam@@ mal species ?
H -0.1525060087442398 Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
P -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288
> Why is it rare to discover new marine mammal species?
S-0 Why is it rare to discover new marine mam@@ mal species ?
H-0 -0.0643349438905716 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
P-0 -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015
This generation script produces three types of outputs: a line prefixed
with *O* is a copy of the original source sentence; *H* is the
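
The preprocessing and postprocessing that ``--tokenizer moses`` and ``--bpe subword_nmt`` enable can also be reproduced outside of fairseq. A rough sketch using the sacremoses and subword-nmt packages directly (assuming both are installed; the hypothesis string is the example from above, not real model output):

# Approximately what fairseq-interactive now does per input line with
# --tokenizer moses --bpe subword_nmt (sketch only).
from sacremoses import MosesTokenizer, MosesDetokenizer
from subword_nmt import apply_bpe

tok = MosesTokenizer('en')
detok = MosesDetokenizer('fr')
with open('wmt14.en-fr.fconv-py/bpecodes') as codes:
    bpe = apply_bpe.BPE(codes)

line = "Why is it rare to discover new marine mammal species?"
tokenized = tok.tokenize(line, return_str=True)
segmented = bpe.process_line(tokenized)        # inserts "@@ " continuation markers
# ... the model generates a hypothesis from `segmented` ...
hypo = "Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?"
restored = (hypo + ' ').replace('@@ ', '').rstrip()   # drop BPE markers
print(detok.detokenize(restored.split()))      # e.g. "Pourquoi est-il rare de découvrir ..."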

View File

@@ -244,11 +244,12 @@ $ SRC=de
$ sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
| python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
> iwslt17.test.${SRC}-en.${SRC}.bpe
$ cat iwslt17.test.${SRC}-en.${SRC}.bpe | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
--task multilingual_translation --source-lang ${SRC} --target-lang en \
--path checkpoints/multilingual_transformer/checkpoint_best.pt \
--buffer 2000 --batch-size 128 \
--beam 5 --remove-bpe=sentencepiece \
$ cat iwslt17.test.${SRC}-en.${SRC}.bpe \
| fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
--task multilingual_translation --source-lang ${SRC} --target-lang en \
--path checkpoints/multilingual_transformer/checkpoint_best.pt \
--buffer 2000 --batch-size 128 \
--beam 5 --remove-bpe=sentencepiece \
> iwslt17.test.${SRC}-en.en.sys
$ grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \
| sacrebleu --test-set iwslt17 --language-pair ${SRC}-en
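
Here `--remove-bpe=sentencepiece` strips the sentencepiece segmentation from the hypotheses before they are written to the output file; the postprocessing roughly amounts to the following (sketch, not fairseq's exact code):

# Sentencepiece marks word boundaries with "\u2581" (▁); undoing the
# segmentation means dropping spaces and turning "▁" back into spaces.
def remove_sentencepiece_bpe(sentence: str) -> str:
    return sentence.replace(' ', '').replace('\u2581', ' ').strip()

print(remove_sentencepiece_bpe('▁How ▁are ▁you ?'))   # "How are you?"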

View File

@@ -58,11 +58,12 @@ Next apply BPE on the fly and run generation for each expert:
$ BPEROOT=examples/translation/subword-nmt/
$ BPE_CODE=examples/translation/wmt17_en_de/code
$ for EXPERT in $(seq 0 2); do \
cat wmt14-en-de.extra_refs.tok | grep ^S | cut -f 2 | \
python $BPEROOT/apply_bpe.py -c $BPE_CODE | \
fairseq-interactive data-bin/wmt17_en_de \
cat wmt14-en-de.extra_refs.tok \
| grep ^S | cut -f 2 \
| fairseq-interactive data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
--beam 1 \
--bpe subword_nmt --bpe-codes $BPE_CODE \
--buffer-size 500 --max-tokens 6000 \
--task translation_moe \
--method hMoElp --mean-pool-gating-network \

View File

@@ -14,7 +14,7 @@ from fairseq import registry
build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry(
'--tokenizer',
default='space',
default=None,
)
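
Switching the default from 'space' to None makes tokenization opt-in: build_tokenizer is expected to return None unless --tokenizer is given, which is what the updated interactive.py checks for below. A simplified, self-contained sketch of the registry pattern (hypothetical code, not fairseq's actual registry module):

import argparse

# Register implementations under a name; build the chosen one from parsed args.
TOKENIZER_REGISTRY = {}

def register_tokenizer(name):
    def register(cls):
        TOKENIZER_REGISTRY[name] = cls
        return cls
    return register

def build_tokenizer(args):
    if args.tokenizer is None:      # default=None -> no tokenization at all
        return None
    return TOKENIZER_REGISTRY[args.tokenizer](args)

@register_tokenizer('space')
class SpaceTokenizer:
    def __init__(self, args):
        pass
    def encode(self, x):
        return ' '.join(x.split())
    def decode(self, x):
        return x

parser = argparse.ArgumentParser()
parser.add_argument('--tokenizer', default=None, choices=TOKENIZER_REGISTRY.keys())
print(build_tokenizer(parser.parse_args([])))                          # None
print(build_tokenizer(parser.parse_args(['--tokenizer', 'space'])))    # a SpaceTokenizer instance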

View File

@@ -14,13 +14,13 @@ class MosesTokenizer(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('-s', '--source-lang', default='en', metavar='SRC',
parser.add_argument('--moses-source-lang', default='en', metavar='SRC',
help='source language')
parser.add_argument('-t', '--target-lang', default='en', metavar='TARGET',
parser.add_argument('--moses-target-lang', default='en', metavar='TARGET',
help='target language')
parser.add_argument('--aggressive-dash-splits', action='store_true', default=False,
help='triggers dash split rules')
parser.add_argument('--no-escape', action='store_true', default=False,
parser.add_argument('--moses-no-dash-splits', action='store_true', default=False,
help='don\'t apply dash split rules')
parser.add_argument('--moses-no-escape', action='store_true', default=False,
help='don\'t perform HTML escaping on apostrophe, quotes, etc.')
# fmt: on
@@ -28,17 +28,17 @@ class MosesTokenizer(object):
self.args = args
try:
from sacremoses import MosesTokenizer, MosesDetokenizer
self.tok = MosesTokenizer(args.source_lang)
self.detok = MosesDetokenizer(args.target_lang)
self.tok = MosesTokenizer(args.moses_source_lang)
self.detok = MosesDetokenizer(args.moses_target_lang)
except ImportError:
raise ImportError('Please install Moses tokenizer with: pip install sacremoses')
def encode(self, x: str) -> str:
return self.tok.tokenize(
x,
aggressive_dash_splits=self.args.aggressive_dash_splits,
aggressive_dash_splits=(not self.args.moses_no_dash_splits),
return_str=True,
escape=(not self.args.no_escape),
escape=(not self.args.moses_no_escape),
)
def decode(self, x: str) -> str:
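
The prefixed --moses-* names presumably avoid clashing with fairseq's top-level --source-lang/--target-lang options, and they make the wrapper easy to drive by hand. A small usage sketch of the class above (argument names taken from this diff, values illustrative):

import argparse

# Hypothetical standalone use of the MosesTokenizer wrapper defined above.
args = argparse.Namespace(
    moses_source_lang='en',
    moses_target_lang='fr',
    moses_no_dash_splits=False,   # aggressive dash splits stay on by default
    moses_no_escape=False,        # HTML escaping stays on by default
)
tok = MosesTokenizer(args)
print(tok.encode("Why is it rare to discover new marine mammal species?"))
# e.g. "Why is it rare to discover new marine mammal species ?"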

View File

@@ -17,7 +17,7 @@ class SpaceTokenizer(object):
self.space_tok = re.compile(r"\s+")
def encode(self, x: str) -> str:
return self.space_tok.sub(" ", x).strip().split()
return self.space_tok.sub(' ', x)
def decode(self, x: str) -> str:
return x
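
The change above makes encode honour its -> str annotation: the old body returned a list of tokens, while the BPE stage downstream expects a whitespace-normalized string. Roughly:

import re

# Old vs. new behavior of SpaceTokenizer.encode (sketch):
space_tok = re.compile(r"\s+")
x = "  hello   world  "
print(space_tok.sub(" ", x).strip().split())   # old: ['hello', 'world'] -- a list, despite -> str
print(space_tok.sub(' ', x))                   # new: ' hello world '    -- a single string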

View File

@@ -22,6 +22,8 @@ class SubwordNMTBPE(object):
# fmt: on
def __init__(self, args):
if args.bpe_codes is None:
raise ValueError('--bpe-codes is required for --bpe=subword_nmt')
codes = file_utils.cached_path(args.bpe_codes)
try:
from subword_nmt import apply_bpe

View File

@@ -160,7 +160,7 @@ class SequenceGenerator(object):
scores_buf = scores.clone()
tokens = src_tokens.data.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = bos_token or self.eos
tokens[:, 0] = self.eos if bos_token is None else bos_token
attn, attn_buf = None, None
nonpad_idxs = None
if prefix_tokens is not None:
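
The old `bos_token or self.eos` silently falls back to EOS whenever the BOS index happens to be 0, because 0 is falsy in Python; the explicit None check keeps a legitimate index of 0. A two-line illustration:

bos_token, eos = 0, 2
print(bos_token or eos)                         # 2 -- falls back to EOS because 0 is falsy
print(eos if bos_token is None else bos_token)  # 0 -- only a missing (None) BOS falls back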
@@ -618,10 +618,8 @@ class EnsembleModel(torch.nn.Module):
decoder_out[0].div_(temperature)
attn = decoder_out[1]
if type(attn) is dict:
attn = attn['attn']
attn = attn.get('attn', None)
if attn is not None:
if type(attn) is dict:
attn = attn['attn']
attn = attn[:, -1, :]
probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
probs = probs[:, -1, :]
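
Using `.get('attn', None)` lets generation proceed when a decoder's extra output dict carries no attention at all, instead of raising. For example:

# Hypothetical decoder extra output without an 'attn' entry:
extra = {'inner_states': []}
print(extra.get('attn', None))   # new behavior: None, so alignment is simply skipped
# extra['attn']                  # old behavior: KeyError when the key is absent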

View File

@@ -15,6 +15,7 @@ import fileinput
import torch
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.data import transforms
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
@@ -101,17 +102,23 @@ def main(args):
# Initialize generator
generator = task.build_generator(args)
# Hack to support GPT-2 BPE
if args.remove_bpe == 'gpt2':
from fairseq.gpt2_bpe.gpt2_encoding import get_encoder
decoder = get_encoder(
'fairseq/gpt2_bpe/encoder.json',
'fairseq/gpt2_bpe/vocab.bpe',
)
encode_fn = lambda x: ' '.join(map(str, decoder.encode(x)))
else:
decoder = None
encode_fn = lambda x: x
# Handle tokenization and BPE
tokenizer = transforms.build_tokenizer(args)
bpe = transforms.build_bpe(args)
def encode_fn(x):
if tokenizer is not None:
x = tokenizer.encode(x)
if bpe is not None:
x = bpe.encode(x)
return x
def decode_fn(x):
if bpe is not None:
x = bpe.decode(x)
if tokenizer is not None:
x = tokenizer.decode(x)
return x
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
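
With both transforms configured, encode_fn and decode_fn apply them in opposite orders: tokenize then BPE-encode on the way in, BPE-decode then detokenize on the way out. A rough round trip through the two helpers above (the intermediate string is illustrative, not real output):

raw = "Why is it rare to discover new marine mammal species?"
pre = encode_fn(raw)    # e.g. "Why is it rare to discover new marine mam@@ mal species ?"
post = decode_fn(pre)   # BPE and tokenization undone; approximately `raw` again
print(pre)
print(post)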
@@ -162,8 +169,7 @@ def main(args):
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
if decoder is not None:
hypo_str = decoder.decode(map(int, hypo_str.strip().split()))
hypo_str = decode_fn(hypo_str)
print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
print('P-{}\t{}'.format(
id,