2018-02-28 01:09:42 +03:00
|
|
|
#!/usr/bin/env python3 -u
|
2017-09-15 03:22:43 +03:00
|
|
|
# Copyright (c) 2017-present, Facebook, Inc.
|
|
|
|
# All rights reserved.
|
|
|
|
#
|
|
|
|
# This source code is licensed under the license found in the LICENSE file in
|
|
|
|
# the root directory of this source tree. An additional grant of patent rights
|
|
|
|
# can be found in the PATENTS file in the same directory.
|
2018-09-03 18:57:53 +03:00
|
|
|
"""
|
|
|
|
Translate pre-processed data with a trained model.
|
|
|
|
"""
|
2017-09-15 03:22:43 +03:00
|
|
|
|
|
|
|
import torch
|
|
|
|
|
2019-03-12 00:15:10 +03:00
|
|
|
from fairseq import bleu, options, progress_bar, tasks, utils
|
2017-09-15 03:22:43 +03:00
|
|
|
from fairseq.meters import StopwatchMeter, TimeMeter
|
2019-01-16 23:05:22 +03:00
|
|
|
from fairseq.utils import import_user_module
|
2017-09-15 03:22:43 +03:00
|
|
|
|
|
|
|
|
2018-02-28 01:09:42 +03:00
|
|
|
def main(args):
    """Generate translations for ``args.gen_subset`` and optionally score them.

    Loads the model ensemble listed in ``args.path`` (colon-separated paths),
    runs the task's generator over every batch of the subset, and prints, per
    sentence, source (``S-``), target (``T-``), hypothesis (``H-``),
    positional-score (``P-``) and, with ``--print-alignment``, alignment
    (``A-``) lines. ``--quiet`` suppresses all of these per-sentence lines.
    The top hypothesis of every sentence that has a reference is added to the
    BLEU scorer.

    Args:
        args: parsed generation options (from ``options.get_generation_parser``).
            ``args.max_tokens`` is set to 12000 in place when neither
            ``--max-tokens`` nor ``--max-sentences`` was given.

    Returns:
        The scorer (``bleu.SacrebleuScorer`` when ``--sacrebleu`` is set,
        otherwise ``bleu.Scorer``) holding the accumulated statistics.

    Raises:
        AssertionError: if ``--path`` is missing, ``--sampling`` is used with
            ``--nbest`` different from ``--beam``, or ``--replace-unk`` is used
            without ``--raw-text``. (These are ``assert`` statements, so they
            are skipped under ``python -O``.)
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries.
    # getattr's default only covers AttributeError; a task whose
    # source_dictionary raises NotImplementedError is also treated as having
    # no source dictionary.
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes arbitrary Python taken from
    # --model-overrides. That is tolerable only because the value comes from
    # the local command line; never feed untrusted input here
    # (ast.literal_eval would be enough for a plain dict literal).
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            # Optionally force the first --prefix-size target tokens.
            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            # hypos[k] is the list of candidates for the k-th sentence of the
            # batch; only the best one's length counts towards throughput.
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding.
                # NOTE(review): source padding is stripped with tgt_dict.pad();
                # this assumes source and target dictionaries share the pad
                # index — confirm for tasks with distinct dictionaries.
                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions. The n-best list is bounded by
                # --nbest alone: hypos[i] holds this sentence's candidates, so
                # the batch size len(hypos) is not a meaningful bound (it used
                # to truncate n-best output for small batches). A separate
                # index `j` keeps the sentence index `i` intact.
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))
                        ))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id,
                                ' '.join(map(lambda x: str(utils.item(x)), alignment))
                            ))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE.
                            # add_if_not_exist=True adds unseen words to tgt_dict.
                            target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            # String-based scorer (SacrebleuScorer path).
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    # Guard the rates against zero elapsed generation time (e.g. an empty
    # subset, where inference never runs); the unguarded divisions raised
    # ZeroDivisionError. tokens/s is tokens timed (gen_timer.n) divided by
    # seconds timed (gen_timer.sum).
    elapsed = gen_timer.sum
    sentences_per_sec = num_sentences / elapsed if elapsed > 0 else 0.
    tokens_per_sec = gen_timer.n / elapsed if elapsed > 0 else 0.
    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, elapsed, sentences_per_sec, tokens_per_sec))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
    return scorer
|
2017-09-15 03:22:43 +03:00
|
|
|
|
|
|
|
|
2019-02-05 18:46:44 +03:00
|
|
|
def cli_main():
    """Command-line entry point: parse generation options and run ``main``.

    Builds the generation argument parser, resolves model-architecture
    arguments, and hands the resulting namespace to ``main``. The scorer
    returned by ``main`` is discarded.
    """
    generation_args = options.parse_args_and_arch(options.get_generation_parser())
    main(generation_args)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    cli_main()
|