Improve interactive generation (support --tokenizer and --bpe)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/734

Differential Revision: D16377044

Pulled By: myleott

fbshipit-source-id: 37d5553d76aa7c653113fec089f59710281c31d7
Myle Ott authored on 2019-07-19 06:34:46 -07:00, committed by Facebook Github Bot
parent be5821b82b
commit 8af5554269
9 changed files with 56 additions and 44 deletions

View File

@@ -19,23 +19,27 @@ flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
 using ``tokenizer.perl`` from
 `mosesdecoder <https://github.com/moses-smt/mosesdecoder>`__.
 
-Let's use :ref:`fairseq-interactive` to generate translations
-interactively. Here, we use a beam size of 5:
+Let's use :ref:`fairseq-interactive` to generate translations interactively.
+Here, we use a beam size of 5 and preprocess the input with the Moses
+tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically
+remove the BPE continuation markers and detokenize the output.
 
 .. code-block:: console
 
     > MODEL_DIR=wmt14.en-fr.fconv-py
    > fairseq-interactive \
        --path $MODEL_DIR/model.pt $MODEL_DIR \
-        --beam 5 --source-lang en --target-lang fr
+        --beam 5 --source-lang en --target-lang fr \
+        --tokenizer moses \
+        --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
    | loading model(s) from wmt14.en-fr.fconv-py/model.pt
    | [en] dictionary: 44206 types
    | [fr] dictionary: 44463 types
    | Type the input sentence and press return:
-    > Why is it rare to discover new marine mam@@ mal species ?
-    O       Why is it rare to discover new marine mam@@ mal species ?
-    H       -0.1525060087442398     Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
-    P       -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288
+    > Why is it rare to discover new marine mammal species?
+    S-0     Why is it rare to discover new marine mam@@ mal species ?
+    H-0     -0.0643349438905716     Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
+    P-0     -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015
 
 This generation script produces three types of outputs: a line prefixed
 with *O* is a copy of the original source sentence; *H* is the
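
The preprocessing that the new ``--tokenizer moses --bpe subword_nmt`` flags
apply to each input line can be reproduced outside of fairseq-interactive.
A minimal sketch, assuming ``pip install sacremoses subword-nmt`` and an
unpacked wmt14.en-fr.fconv-py model directory (the paths are illustrative):

    # Sketch of the per-line preprocessing done by --tokenizer moses and
    # --bpe subword_nmt; assumes sacremoses and subword-nmt are installed.
    from sacremoses import MosesTokenizer
    from subword_nmt import apply_bpe

    tok = MosesTokenizer('en')
    with open('wmt14.en-fr.fconv-py/bpecodes') as f:
        bpe = apply_bpe.BPE(f)

    line = 'Why is it rare to discover new marine mammal species?'
    tokenized = tok.tokenize(line, return_str=True)  # Moses tokenization
    segmented = bpe.process_line(tokenized)          # e.g. '... mam@@ mal ...'

On output the same two steps run in reverse: BPE continuation markers are
merged away, then the Moses detokenizer restores the surface form.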

View File

@ -244,11 +244,12 @@ $ SRC=de
$ sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \ $ sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
| python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \ | python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
> iwslt17.test.${SRC}-en.${SRC}.bpe > iwslt17.test.${SRC}-en.${SRC}.bpe
$ cat iwslt17.test.${SRC}-en.${SRC}.bpe | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \ $ cat iwslt17.test.${SRC}-en.${SRC}.bpe \
--task multilingual_translation --source-lang ${SRC} --target-lang en \ | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
--path checkpoints/multilingual_transformer/checkpoint_best.pt \ --task multilingual_translation --source-lang ${SRC} --target-lang en \
--buffer 2000 --batch-size 128 \ --path checkpoints/multilingual_transformer/checkpoint_best.pt \
--beam 5 --remove-bpe=sentencepiece \ --buffer 2000 --batch-size 128 \
--beam 5 --remove-bpe=sentencepiece \
> iwslt17.test.${SRC}-en.en.sys > iwslt17.test.${SRC}-en.en.sys
$ grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \ $ grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \
| sacrebleu --test-set iwslt17 --language-pair ${SRC}-en | sacrebleu --test-set iwslt17 --language-pair ${SRC}-en

View File

@@ -58,11 +58,12 @@ Next apply BPE on the fly and run generation for each expert:
 $ BPEROOT=examples/translation/subword-nmt/
 $ BPE_CODE=examples/translation/wmt17_en_de/code
 $ for EXPERT in $(seq 0 2); do \
-    cat wmt14-en-de.extra_refs.tok | grep ^S | cut -f 2 | \
-    python $BPEROOT/apply_bpe.py -c $BPE_CODE | \
-    fairseq-interactive data-bin/wmt17_en_de \
+    cat wmt14-en-de.extra_refs.tok \
+    | grep ^S | cut -f 2 \
+    | fairseq-interactive data-bin/wmt17_en_de \
     --path checkpoints/checkpoint_best.pt \
-    --beam 1 --remove-bpe \
+    --beam 1 \
+    --bpe subword_nmt --bpe-codes $BPE_CODE \
     --buffer-size 500 --max-tokens 6000 \
     --task translation_moe \
     --method hMoElp --mean-pool-gating-network \

View File

@@ -14,7 +14,7 @@ from fairseq import registry
 build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry(
     '--tokenizer',
-    default='space',
+    default=None,
 )
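
Changing the default from 'space' to None means no tokenizer is constructed
unless --tokenizer is passed explicitly, so pipelines that feed pre-tokenized
text keep their old behavior. A rough, simplified sketch of the registry
pattern used here (hypothetical, not fairseq's exact implementation):

    # Simplified sketch of the setup_registry pattern: the builder returns
    # None when the flag is unset, letting callers skip the transform.
    def setup_registry(registry_name, default=None):
        REGISTRY = {}
        arg_name = registry_name.lstrip('-')  # '--tokenizer' -> 'tokenizer'

        def build_x(args):
            choice = getattr(args, arg_name, default)
            if choice is None:
                return None
            return REGISTRY[choice](args)  # instantiate the registered class

        def register_x(name):
            def register_cls(cls):
                REGISTRY[name] = cls
                return cls
            return register_cls

        return build_x, register_x, REGISTRY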

View File

@@ -14,13 +14,13 @@ class MosesTokenizer(object):
     @staticmethod
     def add_args(parser):
         # fmt: off
-        parser.add_argument('-s', '--source-lang', default='en', metavar='SRC',
+        parser.add_argument('--moses-source-lang', default='en', metavar='SRC',
                             help='source language')
-        parser.add_argument('-t', '--target-lang', default='en', metavar='TARGET',
+        parser.add_argument('--moses-target-lang', default='en', metavar='TARGET',
                             help='target language')
-        parser.add_argument('--aggressive-dash-splits', action='store_true', default=False,
-                            help='triggers dash split rules')
-        parser.add_argument('--no-escape', action='store_true', default=False,
+        parser.add_argument('--moses-no-dash-splits', action='store_true', default=False,
+                            help='don\'t apply dash split rules')
+        parser.add_argument('--moses-no-escape', action='store_true', default=False,
                             help='don\'t perform HTML escaping on apostrophy, quotes, etc.')
         # fmt: on

@@ -28,17 +28,17 @@ class MosesTokenizer(object):
         self.args = args
         try:
             from sacremoses import MosesTokenizer, MosesDetokenizer
-            self.tok = MosesTokenizer(args.source_lang)
-            self.detok = MosesDetokenizer(args.target_lang)
+            self.tok = MosesTokenizer(args.moses_source_lang)
+            self.detok = MosesDetokenizer(args.moses_target_lang)
         except ImportError:
             raise ImportError('Please install Moses tokenizer with: pip install sacremoses')
 
     def encode(self, x: str) -> str:
         return self.tok.tokenize(
             x,
-            aggressive_dash_splits=self.args.aggressive_dash_splits,
+            aggressive_dash_splits=(not self.args.moses_no_dash_splits),
             return_str=True,
-            escape=(not self.args.no_escape),
+            escape=(not self.args.moses_no_escape),
         )
 
     def decode(self, x: str) -> str:
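
Two things change here: the moses- prefix keeps these options from colliding
with the top-level --source-lang/--target-lang that fairseq-interactive
already defines, and the dash-split flag flips polarity (opt-in
--aggressive-dash-splits becomes opt-out --moses-no-dash-splits, so dash
splitting is now on by default). An illustrative round trip through the same
sacremoses calls the class wraps (exact output may vary across sacremoses
versions):

    # Illustrative Moses tokenize/detokenize round trip via sacremoses.
    from sacremoses import MosesTokenizer, MosesDetokenizer

    tok = MosesTokenizer('en')
    detok = MosesDetokenizer('en')

    encoded = tok.tokenize('Hello, world!', return_str=True,
                           aggressive_dash_splits=True, escape=False)
    # roughly: 'Hello , world !'
    decoded = detok.detokenize(encoded.split())
    # roughly: 'Hello, world!'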

View File

@@ -17,7 +17,7 @@ class SpaceTokenizer(object):
         self.space_tok = re.compile(r"\s+")
 
     def encode(self, x: str) -> str:
-        return self.space_tok.sub(" ", x).strip().split()
+        return self.space_tok.sub(' ', x)
 
     def decode(self, x: str) -> str:
         return x
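
The old body returned a list (via .split()), contradicting the ``-> str``
annotation and the str-in/str-out contract the other tokenizers follow; the
new body only collapses runs of whitespace. For example:

    import re

    space_tok = re.compile(r"\s+")

    space_tok.sub(' ', 'a  b\tc')                  # new: 'a b c' (a string)
    space_tok.sub(' ', 'a  b\tc').strip().split()  # old: ['a', 'b', 'c']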

View File

@@ -22,6 +22,8 @@ class SubwordNMTBPE(object):
         # fmt: on
 
     def __init__(self, args):
+        if args.bpe_codes is None:
+            raise ValueError('--bpe-codes is required for --bpe=subword_nmt')
         codes = file_utils.cached_path(args.bpe_codes)
         try:
             from subword_nmt import apply_bpe
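
Without the guard, omitting --bpe-codes sends None into
file_utils.cached_path, which fails with a less actionable error deep in the
file-handling code; failing fast at construction names the flag to fix. A
trivial repro of the new behavior (the Namespace is a hypothetical stand-in
for the parsed args):

    import argparse

    # Stand-in for the parsed args when --bpe-codes is omitted.
    args = argparse.Namespace(bpe_codes=None)
    if args.bpe_codes is None:
        raise ValueError('--bpe-codes is required for --bpe=subword_nmt')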

View File

@@ -160,7 +160,7 @@ class SequenceGenerator(object):
         scores_buf = scores.clone()
         tokens = src_tokens.data.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
         tokens_buf = tokens.clone()
-        tokens[:, 0] = bos_token or self.eos
+        tokens[:, 0] = self.eos if bos_token is None else bos_token
         attn, attn_buf = None, None
         nonpad_idxs = None
         if prefix_tokens is not None:

@@ -618,10 +618,8 @@ class EnsembleModel(torch.nn.Module):
             decoder_out[0].div_(temperature)
         attn = decoder_out[1]
         if type(attn) is dict:
-            attn = attn['attn']
+            attn = attn.get('attn', None)
         if attn is not None:
-            if type(attn) is dict:
-                attn = attn['attn']
             attn = attn[:, -1, :]
         probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
         probs = probs[:, -1, :]
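
Both hunks are correctness fixes. The first matters because ``or`` tests
truthiness, not presence: a legitimate bos_token of 0 is falsy and was
silently replaced by eos. The second tolerates attention dicts without an
'attn' key and drops the now-redundant inner dict check. A minimal repro of
the first fix:

    bos_token, eos = 0, 2  # token index 0 is valid but falsy

    broken = bos_token or eos                        # old logic -> 2 (wrong)
    fixed = eos if bos_token is None else bos_token  # new logic -> 0 (correct)
    assert (broken, fixed) == (2, 0)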

View File

@@ -15,6 +15,7 @@ import fileinput
 
 import torch
 
 from fairseq import checkpoint_utils, options, tasks, utils
+from fairseq.data import transforms
 
 Batch = namedtuple('Batch', 'ids src_tokens src_lengths')

@@ -101,17 +102,23 @@ def main(args):
     # Initialize generator
     generator = task.build_generator(args)
 
-    # Hack to support GPT-2 BPE
-    if args.remove_bpe == 'gpt2':
-        from fairseq.gpt2_bpe.gpt2_encoding import get_encoder
-        decoder = get_encoder(
-            'fairseq/gpt2_bpe/encoder.json',
-            'fairseq/gpt2_bpe/vocab.bpe',
-        )
-        encode_fn = lambda x: ' '.join(map(str, decoder.encode(x)))
-    else:
-        decoder = None
-        encode_fn = lambda x: x
+    # Handle tokenization and BPE
+    tokenizer = transforms.build_tokenizer(args)
+    bpe = transforms.build_bpe(args)
+
+    def encode_fn(x):
+        if tokenizer is not None:
+            x = tokenizer.encode(x)
+        if bpe is not None:
+            x = bpe.encode(x)
+        return x
+
+    def decode_fn(x):
+        if bpe is not None:
+            x = bpe.decode(x)
+        if tokenizer is not None:
+            x = tokenizer.decode(x)
+        return x
 
     # Load alignment dictionary for unknown word replacement
     # (None if no unknown word replacement, empty if no path to align dictionary)

@@ -162,8 +169,7 @@ def main(args):
                 tgt_dict=tgt_dict,
                 remove_bpe=args.remove_bpe,
             )
-            if decoder is not None:
-                hypo_str = decoder.decode(map(int, hypo_str.strip().split()))
+            hypo_str = decode_fn(hypo_str)
             print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
             print('P-{}\t{}'.format(
                 id,
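
encode_fn and decode_fn are deliberately symmetric: tokenize then BPE on the
way in, undo BPE then detokenize on the way out. A self-contained sketch of
that round trip using hypothetical stand-in transforms (not fairseq classes):

    # Stand-in transforms illustrating the encode/decode symmetry above.
    class StubTokenizer:
        def encode(self, x): return x.lower()
        def decode(self, x): return x.capitalize()

    class StubBPE:
        def encode(self, x): return x.replace('ing', 'ing@@ ')
        def decode(self, x): return x.replace('ing@@ ', 'ing')

    tokenizer, bpe = StubTokenizer(), StubBPE()

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)   # tokenize first...
        if bpe is not None:
            x = bpe.encode(x)         # ...then segment into subwords
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)         # merge subwords first...
        if tokenizer is not None:
            x = tokenizer.decode(x)   # ...then detokenize
        return x

    assert decode_fn(encode_fn('Testing words')) == 'Testing words'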