From 49177c99c45f7d6e99a8f1500d16396e2d7b4519 Mon Sep 17 00:00:00 2001 From: Nathan Ng Date: Thu, 15 Aug 2019 10:03:44 -0700 Subject: [PATCH] Backward reranking public (#667) Summary: Implementation of noisy channel model reranking for release with paper Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/667 Reviewed By: michaelauli Differential Revision: D15901665 Pulled By: nng555 fbshipit-source-id: 2de2c518be8e5828ffad72db3e741b0940623373 --- .gitignore | 3 + eval_lm.py | 6 +- examples/__init__.py | 10 + examples/noisychannel/README.md | 72 +++ examples/noisychannel/__init__.py | 8 + examples/noisychannel/rerank.py | 283 ++++++++++ examples/noisychannel/rerank_generate.py | 246 +++++++++ examples/noisychannel/rerank_options.py | 128 +++++ examples/noisychannel/rerank_score_bw.py | 95 ++++ examples/noisychannel/rerank_score_lm.py | 48 ++ examples/noisychannel/rerank_tune.py | 85 +++ examples/noisychannel/rerank_utils.py | 646 +++++++++++++++++++++++ fairseq/models/transformer.py | 2 +- 13 files changed, 1629 insertions(+), 3 deletions(-) create mode 100644 examples/__init__.py create mode 100644 examples/noisychannel/README.md create mode 100644 examples/noisychannel/__init__.py create mode 100644 examples/noisychannel/rerank.py create mode 100644 examples/noisychannel/rerank_generate.py create mode 100644 examples/noisychannel/rerank_options.py create mode 100644 examples/noisychannel/rerank_score_bw.py create mode 100644 examples/noisychannel/rerank_score_lm.py create mode 100644 examples/noisychannel/rerank_tune.py create mode 100644 examples/noisychannel/rerank_utils.py diff --git a/.gitignore b/.gitignore index 7e4a2d412..84ae18d95 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,6 @@ fairseq/modules/*_layer/*_backward.cu # data data-bin/ + +# reranking +examples/reranking/rerank_data diff --git a/eval_lm.py b/eval_lm.py index e2da64fc1..febed5ac8 100644 --- a/eval_lm.py +++ b/eval_lm.py @@ -146,8 +146,9 @@ def main(parsed_args): hypos = scorer.generate(models, sample) gen_timer.stop(sample['ntokens']) - for hypos_i in hypos: + for i, hypos_i in enumerate(hypos): hypo = hypos_i[0] + sample_id = sample['id'][i] tokens = hypo['tokens'] tgt_len = tokens.numel() @@ -199,7 +200,8 @@ def main(parsed_args): is_bpe = False w = '' if args.output_word_probs: - print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob)) + print(str(int(sample_id)) + " " + + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))) wps_meter.update(sample['ntokens']) t.log({'wps': round(wps_meter.avg)}) diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 000000000..906098c1e --- /dev/null +++ b/examples/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +__version__ = '0.7.2' + +import examples.noisychannel # noqa diff --git a/examples/noisychannel/README.md b/examples/noisychannel/README.md new file mode 100644 index 000000000..a5dd0b9de --- /dev/null +++ b/examples/noisychannel/README.md @@ -0,0 +1,72 @@ +# Simple and Effective Noisy Channel Modeling for Neural Machine Translation (Yee et al., 2019) +This page contains pointers to pre-trained models as well as instructions on how to run the reranking scripts. 
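The reranking scripts score each candidate translation T of a source sentence S as a weighted combination of the rescoring models' log-probabilities, normalized by a length penalty. Schematically (this mirrors `rerank_utils.get_score` introduced in this patch; in the example usage below, model 1 is the channel model P(S|T), model 2 the direct model P(T|S), and the third score comes from the language model P(T)):

```
score(T|S) = (weight1 * log P(S|T) + weight2 * log P(T|S) + weight3 * log P(T)) / |T|^lenpen
```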
## Citation:
```bibtex
@inproceedings{yee2019simple,
  title = {Simple and Effective Noisy Channel Modeling for Neural Machine Translation},
  author = {Kyra Yee and Yann Dauphin and Michael Auli},
  booktitle = {Conference on Empirical Methods in Natural Language Processing},
  year = {2019},
}
```

## Pre-trained Models:

Model | Description | Download
---|---|---
`transformer.noisychannel.de-en` | De->En Forward Model | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2)
`transformer.noisychannel.en-de` | En->De Channel Model | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2)
`transformer_lm.noisychannel.en` | En Language Model | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2)

Test Data: [newstest_wmt17](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2)

## Example usage

```
mkdir rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2 | tar xvjf - -C rerank_example

beam=50
num_trials=1000
fw_name=fw_model_ex
bw_name=bw_model_ex
lm_name=lm_ex
data_dir=rerank_example/hyphen-splitting-mixed-case-wmt17test-wmt14bpe
data_dir_name=wmt17
lm=rerank_example/lm/checkpoint_best.pt
lm_bpe_code=rerank_example/lm/bpe32k.code
lm_dict=rerank_example/lm/dict.txt
batch_size=32
bw=rerank_example/backward_en2de.pt
fw=rerank_example/forward_de2en.pt

# reranking with P(T|S), P(S|T) and P(T)
python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight1 weight3 \
    --lower-bound 0 0 0 --upper-bound 3 3 3 --data-dir-name $data_dir_name \
    --num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
    -n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw \
    --backwards1 --weight2 1 \
    -lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
    --model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name

# reranking with P(T|S) and P(T)
python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight3 \
    --lower-bound 0 0 --upper-bound 3 3 --data-dir-name $data_dir_name \
    --num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
    -n $beam --batch-size $batch_size --score-model1 $fw \
    -lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
    --model1-name $fw_name --gen-model-name $fw_name

# to run with a preconfigured set of hyperparameters for the lenpen and model weights, use rerank.py instead
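# (in this configuration weight1 scales the channel model $bw, weight2 the
# direct model $fw, and weight3 the language model $lm; the weights below
# are one preconfigured setting, not necessarily optimal for other data)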
+python examples/noisychannel/rerank.py $data_dir \ + --lenpen 0.269 --weight1 1 --weight2 0.929 --weight3 0.831 \ + --data-dir-name $data_dir_name --source-lang de --target-lang en --gen-model $fw \ + -n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw --backwards1 \ + -lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \ + --model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name +``` + diff --git a/examples/noisychannel/__init__.py b/examples/noisychannel/__init__.py new file mode 100644 index 000000000..b10ddbd81 --- /dev/null +++ b/examples/noisychannel/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from .rerank_options import * diff --git a/examples/noisychannel/rerank.py b/examples/noisychannel/rerank.py new file mode 100644 index 000000000..c17d64b4a --- /dev/null +++ b/examples/noisychannel/rerank.py @@ -0,0 +1,283 @@ +import rerank_utils +import rerank_generate +import rerank_score_bw +import rerank_score_lm +from fairseq import bleu, options +from fairseq.data import dictionary +from examples.noisychannel import rerank_options +from multiprocessing import Pool + +import math +import numpy as np + + +def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize): + + print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c) + gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args) + dict = dictionary.Dictionary() + scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) + + ordered_hypos = {} + ordered_targets = {} + + for shard_id in range(len(bitext1_lst)): + bitext1 = bitext1_lst[shard_id] + bitext2 = bitext2_lst[shard_id] + gen_output = gen_output_lst[shard_id] + lm_res = lm_res_lst[shard_id] + + total = len(bitext1.rescore_source.keys()) + source_lst = [] + hypo_lst = [] + score_lst = [] + reference_lst = [] + j = 1 + best_score = -math.inf + + for i in range(total): + # length is measured in terms of words, not bpe tokens, since models may not share the same bpe + target_len = len(bitext1.rescore_hypo[i].split()) + + if lm_res is not None: + lm_score = lm_res.score[i] + else: + lm_score = 0 + + if bitext2 is not None: + bitext2_score = bitext2.rescore_score[i] + bitext2_backwards = bitext2.backwards + else: + bitext2_score = None + bitext2_backwards = None + + score = rerank_utils.get_score(a, b, c, target_len, + bitext1.rescore_score[i], bitext2_score, lm_score=lm_score, + lenpen=lenpen, src_len=bitext1.source_lengths[i], + tgt_len=bitext1.target_lengths[i], bitext1_backwards=bitext1.backwards, + bitext2_backwards=bitext2_backwards, normalize=normalize) + + if score > best_score: + best_score = score + best_hypo = bitext1.rescore_hypo[i] + + if j == gen_output.num_hypos[i] or j == args.num_rescore: + j = 1 + hypo_lst.append(best_hypo) + score_lst.append(best_score) + source_lst.append(bitext1.rescore_source[i]) + reference_lst.append(bitext1.rescore_target[i]) + + best_score = -math.inf + best_hypo = "" + else: + j += 1 + + gen_keys = list(sorted(gen_output.no_bpe_target.keys())) + + for key in range(len(gen_keys)): + if args.prefix_len is None: + assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \ + ("pred and rescore hypo mismatch: i: " + str(key) + ", " + 
str(hypo_lst[key]) + str(gen_keys[key])
+                     + str(gen_output.no_bpe_hypo[gen_keys[key]]))
+                sys_tok = dict.encode_line(hypo_lst[key])
+                ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
+                scorer.add(ref_tok, sys_tok)
+
+            else:
+                full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
+                sys_tok = dict.encode_line(full_hypo)
+                ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
+                scorer.add(ref_tok, sys_tok)
+
+        # if only one set of hyperparameters is provided, write the predictions to a file
+        if write_hypos:
+            # recover the original ids from n best list generation
+            for key in range(len(gen_output.no_bpe_target)):
+                if args.prefix_len is None:
+                    assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \
+                        "pred and rescore hypo mismatch: i: " + str(key) + str(hypo_lst[key]) + str(gen_output.no_bpe_hypo[gen_keys[key]])
+                    ordered_hypos[gen_keys[key]] = hypo_lst[key]
+                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
+
+                else:
+                    full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
+                    ordered_hypos[gen_keys[key]] = full_hypo
+                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
+
+    # write the hypos in the original order from nbest list generation
+    if args.num_shards == len(bitext1_lst):
+        with open(target_outfile, 'w') as t:
+            with open(hypo_outfile, 'w') as h:
+                for key in range(len(ordered_hypos)):
+                    t.write(ordered_targets[key])
+                    h.write(ordered_hypos[key])
+
+    res = scorer.result_string(4)
+    if write_hypos:
+        print(res)
+
+    score = rerank_utils.parse_bleu_scoring(res)
+    return score
+
+
+def match_target_hypo(args, target_outfile, hypo_outfile):
+    """combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
+    if len(args.weight1) == 1:
+        res = score_target_hypo(args, args.weight1[0], args.weight2[0],
+                                args.weight3[0], args.lenpen[0], target_outfile,
+                                hypo_outfile, True, args.normalize)
+        rerank_scores = [res]
+    else:
+        print("launching pool")
+        with Pool(32) as p:
+            rerank_scores = p.starmap(score_target_hypo,
+                                      [(args, args.weight1[i], args.weight2[i], args.weight3[i],
+                                        args.lenpen[i], target_outfile, hypo_outfile,
+                                        False, args.normalize) for i in range(len(args.weight1))])
+
+    if len(rerank_scores) > 1:
+        best_index = np.argmax(rerank_scores)
+        best_score = rerank_scores[best_index]
+        print("best score", best_score)
+        print("best lenpen", args.lenpen[best_index])
+        print("best weight1", args.weight1[best_index])
+        print("best weight2", args.weight2[best_index])
+        print("best weight3", args.weight3[best_index])
+        return args.lenpen[best_index], args.weight1[best_index], \
+            args.weight2[best_index], args.weight3[best_index], best_score
+
+    else:
+        return args.lenpen[0], args.weight1[0], args.weight2[0], args.weight3[0], rerank_scores[0]
+
+
+def load_score_files(args):
+    if args.all_shards:
+        shard_ids = list(range(args.num_shards))
+    else:
+        shard_ids = [args.shard_id]
+
+    gen_output_lst = []
+    bitext1_lst = []
+    bitext2_lst = []
+    lm_res1_lst = []
+
+    for shard_id in shard_ids:
+        using_nbest = args.nbest_list is not None
+        pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
+            backwards_preprocessed_dir, lm_preprocessed_dir = \
+            rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
+                                         args.gen_model_name, shard_id, args.num_shards, args.sampling,
+                                         args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
+
+        rerank1_is_gen = 
args.gen_model == args.score_model1 and args.source_prefix_frac is None + rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None + + score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards1) + if args.score_model2 is not None: + score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards2) + if args.language_model is not None: + lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True) + + # get gen output + predictions_bpe_file = pre_gen+"/generate_output_bpe.txt" + if using_nbest: + print("Using predefined n-best list from interactive.py") + predictions_bpe_file = args.nbest_list + gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe, + nbest=using_nbest, prefix_len=args.prefix_len, + target_prefix_frac=args.target_prefix_frac) + + if rerank1_is_gen: + bitext1 = gen_output + else: + bitext1 = rerank_utils.BitextOutput(score1_file, args.backwards1, args.right_to_left1, + args.remove_bpe, args.prefix_len, args.target_prefix_frac, + args.source_prefix_frac) + + if args.score_model2 is not None or args.nbest_list is not None: + if rerank2_is_gen: + bitext2 = gen_output + else: + bitext2 = rerank_utils.BitextOutput(score2_file, args.backwards2, args.right_to_left2, + args.remove_bpe, args.prefix_len, args.target_prefix_frac, + args.source_prefix_frac) + + assert bitext2.source_lengths == bitext1.source_lengths, \ + "source lengths for rescoring models do not match" + assert bitext2.target_lengths == bitext1.target_lengths, \ + "target lengths for rescoring models do not match" + else: + if args.diff_bpe: + assert args.score_model2 is None + bitext2 = gen_output + else: + bitext2 = None + + if args.language_model is not None: + lm_res1 = rerank_utils.LMOutput(lm_score_file, args.lm_dict, args.prefix_len, + args.remove_bpe, args.target_prefix_frac) + else: + lm_res1 = None + + gen_output_lst.append(gen_output) + bitext1_lst.append(bitext1) + bitext2_lst.append(bitext2) + lm_res1_lst.append(lm_res1) + return gen_output_lst, bitext1_lst, bitext2_lst, lm_res1_lst + + +def rerank(args): + if type(args.lenpen) is not list: + args.lenpen = [args.lenpen] + if type(args.weight1) is not list: + args.weight1 = [args.weight1] + if type(args.weight2) is not list: + args.weight2 = [args.weight2] + if type(args.weight3) is not list: + args.weight3 = [args.weight3] + if args.all_shards: + shard_ids = list(range(args.num_shards)) + else: + shard_ids = [args.shard_id] + + for shard_id in shard_ids: + pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \ + backwards_preprocessed_dir, lm_preprocessed_dir = \ + rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset, + args.gen_model_name, shard_id, args.num_shards, args.sampling, + args.prefix_len, args.target_prefix_frac, args.source_prefix_frac) + rerank_generate.gen_and_reprocess_nbest(args) + rerank_score_bw.score_bw(args) + rerank_score_lm.score_lm(args) + + if args.write_hypos is None: + write_targets = pre_gen+"/matched_targets" + write_hypos = pre_gen+"/matched_hypos" + else: + write_targets = args.write_hypos+"_targets" + args.gen_subset + write_hypos = args.write_hypos+"_hypos" + args.gen_subset 
+ + if args.all_shards: + write_targets += "_all_shards" + write_hypos += "_all_shards" + + best_lenpen, best_weight1, best_weight2, best_weight3, best_score = \ + match_target_hypo(args, write_targets, write_hypos) + + return best_lenpen, best_weight1, best_weight2, best_weight3, best_score + + +def cli_main(): + parser = rerank_options.get_reranking_parser() + args = options.parse_args_and_arch(parser) + rerank(args) + + +if __name__ == '__main__': + cli_main() diff --git a/examples/noisychannel/rerank_generate.py b/examples/noisychannel/rerank_generate.py new file mode 100644 index 000000000..27dcdb599 --- /dev/null +++ b/examples/noisychannel/rerank_generate.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 -u +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import rerank_utils +import os +import subprocess +from examples.noisychannel import rerank_options +from fairseq import options +import generate +import preprocess +from contextlib import redirect_stdout + +""" +Generate n-best translations using a trained model. +""" + +def gen_and_reprocess_nbest(args): + if args.score_dict_dir is None: + args.score_dict_dir = args.data + if args.prefix_len is not None: + assert args.right_to_left1 is False, "prefix length not compatible with right to left models" + assert args.right_to_left2 is False, "prefix length not compatible with right to left models" + + if args.nbest_list is not None: + assert args.score_model2 is None + + if args.backwards1: + scorer1_src = args.target_lang + scorer1_tgt = args.source_lang + else: + scorer1_src = args.source_lang + scorer1_tgt = args.target_lang + + store_data = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+args.data_dir_name + if not os.path.exists(store_data): + os.makedirs(store_data) + + pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \ + backwards_preprocessed_dir, lm_preprocessed_dir = \ + rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset, + args.gen_model_name, args.shard_id, args.num_shards, + args.sampling, args.prefix_len, args.target_prefix_frac, + args.source_prefix_frac) + assert not (args.right_to_left1 and args.backwards1), "backwards right to left not supported" + assert not (args.right_to_left2 and args.backwards2), "backwards right to left not supported" + assert not (args.prefix_len is not None and args.target_prefix_frac is not None), \ + "target prefix frac and target prefix len incompatible" + + # make directory to store generation results + if not os.path.exists(pre_gen): + os.makedirs(pre_gen) + + rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None + rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None + + if args.nbest_list is not None: + rerank2_is_gen = True + + # make directories to store preprossed nbest list for reranking + if not os.path.exists(left_to_right_preprocessed_dir): + os.makedirs(left_to_right_preprocessed_dir) + if not os.path.exists(right_to_left_preprocessed_dir): + os.makedirs(right_to_left_preprocessed_dir) + if not os.path.exists(lm_preprocessed_dir): + os.makedirs(lm_preprocessed_dir) + if not os.path.exists(backwards_preprocessed_dir): + os.makedirs(backwards_preprocessed_dir) + + score1_file = 
rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards1) + if args.score_model2 is not None: + score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards2) + + predictions_bpe_file = pre_gen+"/generate_output_bpe.txt" + + using_nbest = args.nbest_list is not None + + if using_nbest: + print("Using predefined n-best list from interactive.py") + predictions_bpe_file = args.nbest_list + + else: + if not os.path.isfile(predictions_bpe_file): + print("STEP 1: generate predictions using the p(T|S) model with bpe") + print(args.data) + param1 = [args.data, + "--path", args.gen_model, + "--shard-id", str(args.shard_id), + "--num-shards", str(args.num_shards), + "--nbest", str(args.num_rescore), + "--batch-size", str(args.batch_size), + "--beam", str(args.num_rescore), + "--max-sentences", str(args.num_rescore), + "--gen-subset", args.gen_subset, + "--source-lang", args.source_lang, + "--target-lang", args.target_lang] + if args.sampling: + param1 += ["--sampling"] + + gen_parser = options.get_generation_parser() + input_args = options.parse_args_and_arch(gen_parser, param1) + + print(input_args) + with open(predictions_bpe_file, 'w') as f: + with redirect_stdout(f): + generate.main(input_args) + + gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe, + nbest=using_nbest, prefix_len=args.prefix_len, + target_prefix_frac=args.target_prefix_frac) + + if args.diff_bpe: + rerank_utils.write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo, + gen_output.no_bpe_target, pre_gen+"/source_gen_bpe."+args.source_lang, + pre_gen+"/target_gen_bpe."+args.target_lang, + pre_gen+"/reference_gen_bpe."+args.target_lang) + bitext_bpe = args.rescore_bpe_code + bpe_src_param = ["-c", bitext_bpe, + "--input", pre_gen+"/source_gen_bpe."+args.source_lang, + "--output", pre_gen+"/rescore_data."+args.source_lang] + bpe_tgt_param = ["-c", bitext_bpe, + "--input", pre_gen+"/target_gen_bpe."+args.target_lang, + "--output", pre_gen+"/rescore_data."+args.target_lang] + + subprocess.call(["python", + os.path.join(os.path.dirname(__file__), + "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param, + shell=False) + + subprocess.call(["python", + os.path.join(os.path.dirname(__file__), + "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_tgt_param, + shell=False) + + if (not os.path.isfile(score1_file) and not rerank1_is_gen) or \ + (args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen): + print("STEP 2: process the output of generate.py so we have clean text files with the translations") + + rescore_file = "/rescore_data" + if args.prefix_len is not None: + prefix_len_rescore_file = rescore_file + "prefix"+str(args.prefix_len) + if args.target_prefix_frac is not None: + target_prefix_frac_rescore_file = rescore_file + "target_prefix_frac"+str(args.target_prefix_frac) + if args.source_prefix_frac is not None: + source_prefix_frac_rescore_file = rescore_file + "source_prefix_frac"+str(args.source_prefix_frac) + + if not args.right_to_left1 or not args.right_to_left2: + if not args.diff_bpe: + rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target, + pre_gen+rescore_file+"."+args.source_lang, + pre_gen+rescore_file+"."+args.target_lang, 
+ pre_gen+"/reference_file", bpe_symbol=args.remove_bpe) + if args.prefix_len is not None: + bw_rescore_file = prefix_len_rescore_file + rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target, + pre_gen+prefix_len_rescore_file+"."+args.source_lang, + pre_gen+prefix_len_rescore_file+"."+args.target_lang, + pre_gen+"/reference_file", prefix_len=args.prefix_len, + bpe_symbol=args.remove_bpe) + elif args.target_prefix_frac is not None: + bw_rescore_file = target_prefix_frac_rescore_file + rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target, + pre_gen+target_prefix_frac_rescore_file+"."+args.source_lang, + pre_gen+target_prefix_frac_rescore_file+"."+args.target_lang, + pre_gen+"/reference_file", bpe_symbol=args.remove_bpe, + target_prefix_frac=args.target_prefix_frac) + else: + bw_rescore_file = rescore_file + + if args.source_prefix_frac is not None: + fw_rescore_file = source_prefix_frac_rescore_file + rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target, + pre_gen+source_prefix_frac_rescore_file+"."+args.source_lang, + pre_gen+source_prefix_frac_rescore_file+"."+args.target_lang, + pre_gen+"/reference_file", bpe_symbol=args.remove_bpe, + source_prefix_frac=args.source_prefix_frac) + else: + fw_rescore_file = rescore_file + + if args.right_to_left1 or args.right_to_left2: + rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target, + pre_gen+"/right_to_left_rescore_data."+args.source_lang, + pre_gen+"/right_to_left_rescore_data."+args.target_lang, + pre_gen+"/right_to_left_reference_file", + right_to_left=True, bpe_symbol=args.remove_bpe) + + print("STEP 3: binarize the translations") + if not args.right_to_left1 or args.score_model2 is not None and not args.right_to_left2 or not rerank1_is_gen: + + if args.backwards1 or args.backwards2: + if args.backwards_score_dict_dir is not None: + bw_dict = args.backwards_score_dict_dir + else: + bw_dict = args.score_dict_dir + bw_preprocess_param = ["--source-lang", scorer1_src, + "--target-lang", scorer1_tgt, + "--trainpref", pre_gen+bw_rescore_file, + "--srcdict", bw_dict + "/dict." + scorer1_src + ".txt", + "--tgtdict", bw_dict + "/dict." 
+ scorer1_tgt + ".txt", + "--destdir", backwards_preprocessed_dir] + preprocess_parser = options.get_preprocessing_parser() + input_args = preprocess_parser.parse_args(bw_preprocess_param) + preprocess.main(input_args) + + preprocess_param = ["--source-lang", scorer1_src, + "--target-lang", scorer1_tgt, + "--trainpref", pre_gen+fw_rescore_file, + "--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt", + "--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt", + "--destdir", left_to_right_preprocessed_dir] + preprocess_parser = options.get_preprocessing_parser() + input_args = preprocess_parser.parse_args(preprocess_param) + preprocess.main(input_args) + + if args.right_to_left1 or args.right_to_left2: + preprocess_param = ["--source-lang", scorer1_src, + "--target-lang", scorer1_tgt, + "--trainpref", pre_gen+"/right_to_left_rescore_data", + "--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt", + "--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt", + "--destdir", right_to_left_preprocessed_dir] + preprocess_parser = options.get_preprocessing_parser() + input_args = preprocess_parser.parse_args(preprocess_param) + preprocess.main(input_args) + + return gen_output + + +def cli_main(): + parser = rerank_options.get_reranking_parser() + args = options.parse_args_and_arch(parser) + gen_and_reprocess_nbest(args) + + +if __name__ == '__main__': + cli_main() diff --git a/examples/noisychannel/rerank_options.py b/examples/noisychannel/rerank_options.py new file mode 100644 index 000000000..1f8c748b9 --- /dev/null +++ b/examples/noisychannel/rerank_options.py @@ -0,0 +1,128 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
+ +from fairseq import options + + +def get_reranking_parser(default_task='translation'): + parser = options.get_parser('Generation and reranking', default_task) + add_reranking_args(parser) + return parser + + +def get_tuning_parser(default_task='translation'): + parser = options.get_parser('Reranking tuning', default_task) + add_reranking_args(parser) + add_tuning_args(parser) + return parser + + +def add_reranking_args(parser): + group = parser.add_argument_group("Reranking") + # fmt: off + group.add_argument('--score-model1', '-s1', type=str, metavar='FILE', required=True, + help='path to first model or ensemble of models for rescoring') + group.add_argument('--score-model2', '-s2', type=str, metavar='FILE', required=False, + help='path to second model or ensemble of models for rescoring') + group.add_argument('--num-rescore', '-n', type=int, metavar='N', default=10, + help='the number of candidate hypothesis to rescore') + group.add_argument('-bz', '--batch-size', type=int, metavar='N', default=128, + help='batch size for generating the nbest list') + group.add_argument('--gen-subset', default='test', metavar='SET', choices=['test', 'train', 'valid'], + help='data subset to generate (train, valid, test)') + group.add_argument('--gen-model', default=None, metavar='FILE', + help='the model to generate translations') + group.add_argument('-b1', '--backwards1', action='store_true', + help='whether or not the first model group is backwards') + group.add_argument('-b2', '--backwards2', action='store_true', + help='whether or not the second model group is backwards') + group.add_argument('-a', '--weight1', default=1, nargs='+', type=float, + help='the weight(s) of the first model') + group.add_argument('-b', '--weight2', default=1, nargs='+', type=float, + help='the weight(s) of the second model, or the gen model if using nbest from interactive.py') + group.add_argument('-c', '--weight3', default=1, nargs='+', type=float, + help='the weight(s) of the third model') + + # lm arguments + group.add_argument('-lm', '--language-model', default=None, metavar='FILE', + help='language model for target language to rescore translations') + group.add_argument('--lm-dict', default=None, metavar='FILE', + help='the dict of the language model for the target language') + group.add_argument('--lm-name', default=None, + help='the name of the language model for the target language') + group.add_argument('--lm-bpe-code', default=None, metavar='FILE', + help='the bpe code for the language model for the target language') + group.add_argument('--data-dir-name', default=None, + help='name of data directory') + group.add_argument('--lenpen', default=1, nargs='+', type=float, + help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences') + group.add_argument('--score-dict-dir', default=None, + help='the directory with dictionaries for the scoring models') + group.add_argument('--right-to-left1', action='store_true', + help='whether the first model group is a right to left model') + group.add_argument('--right-to-left2', action='store_true', + help='whether the second model group is a right to left model') + group.add_argument('--remove-bpe', default='@@ ', + help='the bpe symbol, used for the bitext and LM') + group.add_argument('--prefix-len', default=None, type=int, + help='the length of the target prefix to use in rescoring (in terms of words wo bpe)') + group.add_argument('--sampling', action='store_true', + help='use sampling instead of beam search for generating n best list') + 
group.add_argument('--diff-bpe', action='store_true', + help='bpe for rescoring and nbest list not the same') + group.add_argument('--rescore-bpe-code', default=None, + help='bpe code for rescoring models') + group.add_argument('--nbest-list', default=None, + help='use predefined nbest list in interactive.py format') + group.add_argument('--write-hypos', default=None, + help='filename prefix to write hypos to') + group.add_argument('--ref-translation', default=None, + help='reference translation to use with nbest list from interactive.py') + group.add_argument('--backwards-score-dict-dir', default=None, + help='the directory with dictionaries for the backwards model,' + 'if None then it is assumed the fw and backwards models share dictionaries') + + # extra scaling args + group.add_argument('--gen-model-name', default=None, + help='the name of the models that generated the nbest list') + group.add_argument('--model1-name', default=None, + help='the name of the set for model1 group ') + group.add_argument('--model2-name', default=None, + help='the name of the set for model2 group') + group.add_argument('--shard-id', default=0, type=int, + help='the id of the shard to generate') + group.add_argument('--num-shards', default=1, type=int, + help='the number of shards to generate across') + group.add_argument('--all-shards', action='store_true', + help='use all shards') + group.add_argument('--target-prefix-frac', default=None, type=float, + help='the fraction of the target prefix to use in rescoring (in terms of words wo bpe)') + group.add_argument('--source-prefix-frac', default=None, type=float, + help='the fraction of the source prefix to use in rescoring (in terms of words wo bpe)') + group.add_argument('--normalize', action='store_true', + help='whether to normalize by src and target len') + + return group + + +def add_tuning_args(parser): + group = parser.add_argument_group("Tuning") + + group.add_argument('--lower-bound', default=[-0.7], nargs='+', type=float, + help='lower bound of search space') + group.add_argument('--upper-bound', default=[3], nargs='+', type=float, + help='upper bound of search space') + group.add_argument('--tune-param', default=['lenpen'], nargs='+', + choices=['lenpen', 'weight1', 'weight2', 'weight3'], + help='the parameter(s) to tune') + group.add_argument('--tune-subset', default='valid', choices=['valid', 'test', 'train'], + help='the subset to tune on ') + group.add_argument('--num-trials', default=1000, type=int, + help='number of trials to do for random search') + group.add_argument('--share-weights', action='store_true', + help='share weight2 and weight 3') + return group diff --git a/examples/noisychannel/rerank_score_bw.py b/examples/noisychannel/rerank_score_bw.py new file mode 100644 index 000000000..c1558022a --- /dev/null +++ b/examples/noisychannel/rerank_score_bw.py @@ -0,0 +1,95 @@ +import rerank_utils +import os +from fairseq import options +from examples.noisychannel import rerank_options +from contextlib import redirect_stdout +import generate + + +def score_bw(args): + if args.backwards1: + scorer1_src = args.target_lang + scorer1_tgt = args.source_lang + else: + scorer1_src = args.source_lang + scorer1_tgt = args.target_lang + + if args.score_model2 is not None: + if args.backwards2: + scorer2_src = args.target_lang + scorer2_tgt = args.source_lang + else: + scorer2_src = args.source_lang + scorer2_tgt = args.target_lang + + rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None + rerank2_is_gen = 
args.gen_model == args.score_model2 and args.source_prefix_frac is None + + pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \ + backwards_preprocessed_dir, lm_preprocessed_dir = \ + rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset, + args.gen_model_name, args.shard_id, args.num_shards, + args.sampling, args.prefix_len, args.target_prefix_frac, + args.source_prefix_frac) + + score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards1) + + if args.score_model2 is not None: + score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards2) + + if args.right_to_left1: + rerank_data1 = right_to_left_preprocessed_dir + elif args.backwards1: + rerank_data1 = backwards_preprocessed_dir + else: + rerank_data1 = left_to_right_preprocessed_dir + + gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"] + if not rerank1_is_gen and not os.path.isfile(score1_file): + print("STEP 4: score the translations for model 1") + + model_param1 = ["--path", args.score_model1, "--source-lang", scorer1_src, "--target-lang", scorer1_tgt] + gen_model1_param = [rerank_data1] + gen_param + model_param1 + + gen_parser = options.get_generation_parser() + input_args = options.parse_args_and_arch(gen_parser, gen_model1_param) + + with open(score1_file, 'w') as f: + with redirect_stdout(f): + generate.main(input_args) + + if args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen: + print("STEP 4: score the translations for model 2") + + if args.right_to_left2: + rerank_data2 = right_to_left_preprocessed_dir + elif args.backwards2: + rerank_data2 = backwards_preprocessed_dir + else: + rerank_data2 = left_to_right_preprocessed_dir + + model_param2 = ["--path", args.score_model2, "--source-lang", scorer2_src, "--target-lang", scorer2_tgt] + gen_model2_param = [rerank_data2] + gen_param + model_param2 + + gen_parser = options.get_generation_parser() + input_args = options.parse_args_and_arch(gen_parser, gen_model2_param) + + with open(score2_file, 'w') as f: + with redirect_stdout(f): + generate.main(input_args) + + +def cli_main(): + parser = rerank_options.get_reranking_parser() + args = options.parse_args_and_arch(parser) + score_bw(args) + + +if __name__ == '__main__': + cli_main() diff --git a/examples/noisychannel/rerank_score_lm.py b/examples/noisychannel/rerank_score_lm.py new file mode 100644 index 000000000..e35e1da6c --- /dev/null +++ b/examples/noisychannel/rerank_score_lm.py @@ -0,0 +1,48 @@ +import rerank_utils +import os +from fairseq import options +from examples.noisychannel import rerank_options + + +def score_lm(args): + using_nbest = args.nbest_list is not None + pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \ + backwards_preprocessed_dir, lm_preprocessed_dir = \ + rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset, + args.gen_model_name, args.shard_id, args.num_shards, + args.sampling, args.prefix_len, args.target_prefix_frac, + args.source_prefix_frac) + + predictions_bpe_file = pre_gen+"/generate_output_bpe.txt" + if using_nbest: + print("Using predefined n-best list from interactive.py") + predictions_bpe_file = args.nbest_list + + 
gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest) + + if args.language_model is not None: + lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True) + + if args.language_model is not None and not os.path.isfile(lm_score_file): + print("STEP 4.5: language modeling for P(T)") + if args.lm_bpe_code is None: + bpe_status = "no bpe" + elif args.lm_bpe_code == "shared": + bpe_status = "shared" + else: + bpe_status = "different" + + rerank_utils.lm_scoring(lm_preprocessed_dir, bpe_status, gen_output, pre_gen, + args.lm_dict, args.lm_name, args.language_model, + args.lm_bpe_code, 128, lm_score_file, args.target_lang, + args.source_lang, prefix_len=args.prefix_len) + + +def cli_main(): + parser = rerank_options.get_reranking_parser() + args = options.parse_args_and_arch(parser) + score_lm(args) + + +if __name__ == '__main__': + cli_main() diff --git a/examples/noisychannel/rerank_tune.py b/examples/noisychannel/rerank_tune.py new file mode 100644 index 000000000..805d87579 --- /dev/null +++ b/examples/noisychannel/rerank_tune.py @@ -0,0 +1,85 @@ +import rerank +import argparse +import numpy as np +import random +from examples.noisychannel import rerank_options +from fairseq import options + + +def random_search(args): + param_values = [] + tuneable_parameters = ['lenpen', 'weight1', 'weight2', 'weight3'] + initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3] + for i, elem in enumerate(initial_params): + if type(elem) is not list: + initial_params[i] = [elem] + else: + initial_params[i] = elem + + tune_parameters = args.tune_param.copy() + for i in range(len(args.tune_param)): + assert args.upper_bound[i] >= args.lower_bound[i] + index = tuneable_parameters.index(args.tune_param[i]) + del tuneable_parameters[index] + del initial_params[index] + + tune_parameters += tuneable_parameters + param_values += initial_params + random.seed(args.seed) + + random_params = np.array([[random.uniform(args.lower_bound[i], args.upper_bound[i]) + for i in range(len(args.tune_param))] + for k in range(args.num_trials)]) + set_params = np.array([[initial_params[i][0] + for i in range(len(tuneable_parameters))] + for k in range(args.num_trials)]) + random_params = np.concatenate((random_params, set_params), 1) + + rerank_args = vars(args).copy() + if args.nbest_list: + rerank_args['gen_subset'] = 'test' + else: + rerank_args['gen_subset'] = args.tune_subset + + for k in range(len(tune_parameters)): + rerank_args[tune_parameters[k]] = list(random_params[:, k]) + + if args.share_weights: + k = tune_parameters.index('weight2') + rerank_args['weight3'] = list(random_params[:, k]) + + rerank_args = argparse.Namespace(**rerank_args) + best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(rerank_args) + rerank_args = vars(args).copy() + rerank_args['lenpen'] = [best_lenpen] + rerank_args['weight1'] = [best_weight1] + rerank_args['weight2'] = [best_weight2] + rerank_args['weight3'] = [best_weight3] + + # write the hypothesis from the valid set from the best trial + + if args.gen_subset != "valid": + rerank_args['gen_subset'] = "valid" + rerank_args = argparse.Namespace(**rerank_args) + rerank.rerank(rerank_args) + + # test with the best hyperparameters on gen subset + rerank_args = vars(args).copy() + rerank_args['gen_subset'] = args.gen_subset + rerank_args['lenpen'] = [best_lenpen] + rerank_args['weight1'] = [best_weight1] + rerank_args['weight2'] = [best_weight2] + 
rerank_args['weight3'] = [best_weight3] + rerank_args = argparse.Namespace(**rerank_args) + rerank.rerank(rerank_args) + + +def cli_main(): + parser = rerank_options.get_tuning_parser() + args = options.parse_args_and_arch(parser) + + random_search(args) + + +if __name__ == '__main__': + cli_main() diff --git a/examples/noisychannel/rerank_utils.py b/examples/noisychannel/rerank_utils.py new file mode 100644 index 000000000..9b8bb7bec --- /dev/null +++ b/examples/noisychannel/rerank_utils.py @@ -0,0 +1,646 @@ +import subprocess +import os +import re +from fairseq import options +import eval_lm +import preprocess +from contextlib import redirect_stdout +import math + + +def reprocess(fle): + # takes in a file of generate.py translation generate_output + # returns a source dict and hypothesis dict, where keys are the ID num (as a string) + # and values and the corresponding source and translation. There may be several translations + # per source, so the values for hypothesis_dict are lists. + # parses output of generate.py + + with open(fle, 'r') as f: + txt = f.read() + + """reprocess generate.py output""" + p = re.compile(r"[STHP][-]\d+\s*") + hp = re.compile(r"(\s*[-]?\d+[.]?\d+\s*)|(\s*(-inf)\s*)") + source_dict = {} + hypothesis_dict = {} + score_dict = {} + target_dict = {} + pos_score_dict = {} + lines = txt.split("\n") + + for line in lines: + line += "\n" + prefix = re.search(p, line) + if prefix is not None: + assert len(prefix.group()) > 2, "prefix id not found" + _, j = prefix.span() + id_num = prefix.group()[2:] + id_num = int(id_num) + line_type = prefix.group()[0] + if line_type == "H": + h_txt = line[j:] + hypo = re.search(hp, h_txt) + assert hypo is not None, ("regular expression failed to find the hypothesis scoring") + _, i = hypo.span() + score = hypo.group() + if id_num in hypothesis_dict: + hypothesis_dict[id_num].append(h_txt[i:]) + score_dict[id_num].append(float(score)) + else: + hypothesis_dict[id_num] = [h_txt[i:]] + score_dict[id_num] = [float(score)] + + elif line_type == "S": + source_dict[id_num] = (line[j:]) + elif line_type == "T": + target_dict[id_num] = (line[j:]) + elif line_type == "P": + pos_scores = (line[j:]).split() + pos_scores = [float(x) for x in pos_scores] + if id_num in pos_score_dict: + pos_score_dict[id_num].append(pos_scores) + else: + pos_score_dict[id_num] = [pos_scores] + + return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict + + +def reprocess_nbest(fle): + """reprocess interactive.py output""" + with open(fle, 'r') as f: + txt = f.read() + + source_dict = {} + hypothesis_dict = {} + score_dict = {} + target_dict = {} + pos_score_dict = {} + lines = txt.split("\n") + + hp = re.compile(r'[-]?\d+[.]?\d+') + j = -1 + + for _i, line in enumerate(lines): + line += "\n" + line_type = line[0] + + if line_type == "H": + hypo = re.search(hp, line) + _, start_index = hypo.span() + score = hypo.group() + if j in score_dict: + score_dict[j].append(float(score)) + hypothesis_dict[j].append(line[start_index:].strip("\t")) + else: + score_dict[j] = [float(score)] + hypothesis_dict[j] = [line[start_index:].strip("\t")] + elif line_type == "O": + j += 1 + source_dict[j] = line[2:] + # we don't have the targets for interactive.py + target_dict[j] = "filler" + + elif line_type == "P": + pos_scores = [float(pos_score) for pos_score in line.split()[1:]] + if j in pos_score_dict: + pos_score_dict[j].append(pos_scores) + else: + pos_score_dict[j] = [pos_scores] + + assert source_dict.keys() == hypothesis_dict.keys() + assert 
source_dict.keys() == pos_score_dict.keys() + assert source_dict.keys() == score_dict.keys() + + return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict + + +def write_reprocessed(sources, hypos, targets, source_outfile, + hypo_outfile, target_outfile, right_to_left=False, + prefix_len=None, bpe_symbol=None, + target_prefix_frac=None, source_prefix_frac=None): + + """writes nbest hypothesis for rescoring""" + assert not (prefix_len is not None and target_prefix_frac is not None), \ + "in writing reprocessed, only one type of prefix may be used" + assert not (prefix_len is not None and source_prefix_frac is not None), \ + "in writing reprocessed, only one type of prefix may be used" + assert not (target_prefix_frac is not None and source_prefix_frac is not None), \ + "in writing reprocessed, only one type of prefix may be used" + + with open(source_outfile, 'w') as source_file, \ + open(hypo_outfile, 'w') as hypo_file, \ + open(target_outfile, 'w') as target_file: + + assert len(sources) == len(hypos), "sources and hypos list length mismatch" + if right_to_left: + for i in range(len(sources)): + for j in range(len(hypos[i])): + if prefix_len is None: + hypo_file.write(make_right_to_left(hypos[i][j])+"\n") + else: + raise NotImplementedError() + source_file.write(make_right_to_left(sources[i])+"\n") + target_file.write(make_right_to_left(targets[i])+"\n") + else: + for i in sorted(sources.keys()): + for j in range(len(hypos[i])): + if prefix_len is not None: + shortened = get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)+"\n" + hypo_file.write(shortened) + source_file.write(sources[i]) + target_file.write(targets[i]) + elif target_prefix_frac is not None: + num_words, shortened, num_bpe_tokens = \ + calc_length_from_frac(hypos[i][j], target_prefix_frac, bpe_symbol) + shortened += "\n" + hypo_file.write(shortened) + source_file.write(sources[i]) + target_file.write(targets[i]) + elif source_prefix_frac is not None: + num_words, shortened, num_bpe_tokensn = \ + calc_length_from_frac(sources[i], source_prefix_frac, bpe_symbol) + shortened += "\n" + hypo_file.write(hypos[i][j]) + source_file.write(shortened) + target_file.write(targets[i]) + else: + hypo_file.write(hypos[i][j]) + source_file.write(sources[i]) + target_file.write(targets[i]) + + +def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol): + # return number of words, (not bpe tokens) that we want + no_bpe_sen = remove_bpe(bpe_sentence, bpe_symbol) + len_sen = len(no_bpe_sen.split()) + + num_words = math.ceil(len_sen * prefix_frac) + prefix = get_prefix_no_bpe(bpe_sentence, bpe_symbol, num_words) + num_bpe_tokens = len(prefix.split()) + return num_words, prefix, num_bpe_tokens + + +def get_prefix(sentence, prefix_len): + """assuming no bpe, gets the prefix of the sentence with prefix_len words""" + tokens = sentence.strip("\n").split() + if prefix_len >= len(tokens): + return sentence.strip("\n") + else: + return " ".join(tokens[:prefix_len]) + + +def get_prefix_no_bpe(sentence, bpe_symbol, prefix_len): + if bpe_symbol is None: + return get_prefix(sentence, prefix_len) + else: + return " ".join(get_prefix_from_len(sentence.split(), bpe_symbol, prefix_len)) + + +def get_prefix_from_len(sentence, bpe_symbol, prefix_len): + """get the prefix of sentence with bpe, with prefix len in terms of words, not bpe tokens""" + bpe_count = sum([bpe_symbol.strip(" ") in t for t in sentence[:prefix_len]]) + if bpe_count == 0: + return sentence[:prefix_len] + else: + return 
sentence[:prefix_len]+get_prefix_from_len(sentence[prefix_len:], bpe_symbol, bpe_count) + + +def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len): + """given a prefix length in terms of words, return the number of bpe tokens""" + prefix = get_prefix_no_bpe(sentence, bpe_symbol, prefix_len) + assert len(remove_bpe(prefix, bpe_symbol).split()) <= prefix_len + return len(prefix.split(" ")) + + +def make_right_to_left(line): + tokens = line.split() + tokens.reverse() + new_line = " ".join(tokens) + return new_line + + +def remove_bpe(line, bpe_symbol): + line = line.replace("\n", '') + line = (line + ' ').replace(bpe_symbol, '').rstrip() + return line+("\n") + + +def remove_bpe_dict(pred_dict, bpe_symbol): + new_dict = {} + for i in pred_dict: + if type(pred_dict[i]) == list: + new_list = [remove_bpe(elem, bpe_symbol) for elem in pred_dict[i]] + new_dict[i] = new_list + else: + new_dict[i] = remove_bpe(pred_dict[i], bpe_symbol) + return new_dict + + +def parse_bleu_scoring(line): + p = re.compile(r'(BLEU4 = )\d+[.]\d+') + res = re.search(p, line) + assert res is not None, line + return float(res.group()[8:]) + + +def get_full_from_prefix(hypo_prefix, hypos): + """given a hypo prefix, recover the first hypo from the list of complete hypos beginning with that prefix""" + for hypo in hypos: + hypo_prefix = hypo_prefix.strip("\n") + len_prefix = len(hypo_prefix) + if hypo[:len_prefix] == hypo_prefix: + return hypo + # no match found + raise Exception() + + +def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=None, + lenpen=None, src_len=None, tgt_len=None, bitext1_backwards=False, + bitext2_backwards=False, normalize=False): + if bitext1_backwards: + bitext1_norm = src_len + else: + bitext1_norm = tgt_len + if bitext_score2 is not None: + if bitext2_backwards: + bitext2_norm = src_len + else: + bitext2_norm = tgt_len + else: + bitext2_norm = 1 + bitext_score2 = 0 + if normalize: + score = a*bitext_score1/bitext1_norm + b*bitext_score2/bitext2_norm+c*lm_score/src_len + else: + score = a*bitext_score1 + b*bitext_score2+c*lm_score + + if lenpen is not None: + score /= (target_len) ** float(lenpen) + + return score + + +class BitextOutput(object): + def __init__(self, output_file, backwards, right_to_left, bpe_symbol, + prefix_len=None, target_prefix_frac=None, source_prefix_frac=None): + """process output from rescoring""" + source, hypo, score, target, pos_score = reprocess(output_file) + if backwards: + self.hypo_fracs = source_prefix_frac + else: + self.hypo_fracs = target_prefix_frac + + # remove length penalty so we can use raw scores + score, num_bpe_tokens = get_score_from_pos(pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards) + source_lengths = {} + target_lengths = {} + + assert hypo.keys() == source.keys(), "key mismatch" + if backwards: + tmp = hypo + hypo = source + source = tmp + for i in source: + # since we are reranking, there should only be one hypo per source sentence + if backwards: + len_src = len(source[i][0].split()) + # record length without + if len_src == num_bpe_tokens[i][0] - 1: + source_lengths[i] = num_bpe_tokens[i][0] - 1 + else: + source_lengths[i] = num_bpe_tokens[i][0] + + target_lengths[i] = len(hypo[i].split()) + + source[i] = remove_bpe(source[i][0], bpe_symbol) + target[i] = remove_bpe(target[i], bpe_symbol) + hypo[i] = remove_bpe(hypo[i], bpe_symbol) + + score[i] = float(score[i][0]) + pos_score[i] = pos_score[i][0] + + else: + len_tgt = len(hypo[i][0].split()) + # record length without + if len_tgt == 
num_bpe_tokens[i][0] - 1: + target_lengths[i] = num_bpe_tokens[i][0] - 1 + else: + target_lengths[i] = num_bpe_tokens[i][0] + + source_lengths[i] = len(source[i].split()) + + if right_to_left: + source[i] = remove_bpe(make_right_to_left(source[i]), bpe_symbol) + target[i] = remove_bpe(make_right_to_left(target[i]), bpe_symbol) + hypo[i] = remove_bpe(make_right_to_left(hypo[i][0]), bpe_symbol) + score[i] = float(score[i][0]) + pos_score[i] = pos_score[i][0] + else: + assert len(hypo[i]) == 1, "expected only one hypothesis per source sentence" + source[i] = remove_bpe(source[i], bpe_symbol) + target[i] = remove_bpe(target[i], bpe_symbol) + hypo[i] = remove_bpe(hypo[i][0], bpe_symbol) + score[i] = float(score[i][0]) + pos_score[i] = pos_score[i][0] + + self.rescore_source = source + self.rescore_hypo = hypo + self.rescore_score = score + self.rescore_target = target + self.rescore_pos_score = pos_score + self.backwards = backwards + self.right_to_left = right_to_left + self.target_lengths = target_lengths + self.source_lengths = source_lengths + + +class BitextOutputFromGen(object): + def __init__(self, predictions_bpe_file, bpe_symbol=None, nbest=False, prefix_len=None, target_prefix_frac=None): + if nbest: + pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess_nbest(predictions_bpe_file) + else: + pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(predictions_bpe_file) + + assert len(pred_source) == len(pred_hypo) + assert len(pred_source) == len(pred_score) + assert len(pred_source) == len(pred_target) + assert len(pred_source) == len(pred_pos_score) + + # remove length penalty so we can use raw scores + pred_score, num_bpe_tokens = get_score_from_pos(pred_pos_score, prefix_len, pred_hypo, + bpe_symbol, target_prefix_frac, False) + + self.source = pred_source + self.target = pred_target + self.score = pred_score + self.pos_score = pred_pos_score + self.hypo = pred_hypo + self.target_lengths = {} + self.source_lengths = {} + + self.no_bpe_source = remove_bpe_dict(pred_source.copy(), bpe_symbol) + self.no_bpe_hypo = remove_bpe_dict(pred_hypo.copy(), bpe_symbol) + self.no_bpe_target = remove_bpe_dict(pred_target.copy(), bpe_symbol) + + # indexes to match those from the rescoring models + self.rescore_source = {} + self.rescore_target = {} + self.rescore_pos_score = {} + self.rescore_hypo = {} + self.rescore_score = {} + self.num_hypos = {} + self.backwards = False + self.right_to_left = False + + index = 0 + + for i in sorted(pred_source.keys()): + for j in range(len(pred_hypo[i])): + + self.target_lengths[index] = len(self.hypo[i][j].split()) + self.source_lengths[index] = len(self.source[i].split()) + + self.rescore_source[index] = self.no_bpe_source[i] + self.rescore_target[index] = self.no_bpe_target[i] + self.rescore_hypo[index] = self.no_bpe_hypo[i][j] + self.rescore_score[index] = float(pred_score[i][j]) + self.rescore_pos_score[index] = pred_pos_score[i][j] + self.num_hypos[index] = len(pred_hypo[i]) + index += 1 + + +def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards): + score_dict = {} + num_bpe_tokens_dict = {} + assert prefix_len is None or hypo_frac is None + for key in pos_score_dict: + score_dict[key] = [] + num_bpe_tokens_dict[key] = [] + for i in range(len(pos_score_dict[key])): + if prefix_len is not None and not backwards: + num_bpe_tokens = get_num_bpe_tokens_from_len(hypo_dict[key][i], bpe_symbol, prefix_len) + score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens])) + 
num_bpe_tokens_dict[key].append(num_bpe_tokens) + elif hypo_frac is not None: + num_words, shortened, hypo_prefix_len = calc_length_from_frac(hypo_dict[key][i], hypo_frac, bpe_symbol) + score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len])) + num_bpe_tokens_dict[key].append(hypo_prefix_len) + else: + score_dict[key].append(sum(pos_score_dict[key][i])) + num_bpe_tokens_dict[key].append(len(pos_score_dict[key][i])) + return score_dict, num_bpe_tokens_dict + + +class LMOutput(object): + def __init__(self, lm_score_file, lm_dict=None, prefix_len=None, bpe_symbol=None, target_prefix_frac=None): + lm_sentences, lm_sen_scores, lm_sen_pos_scores, lm_no_bpe_sentences, lm_bpe_tokens = \ + parse_lm(lm_score_file, prefix_len=prefix_len, + bpe_symbol=bpe_symbol, target_prefix_frac=target_prefix_frac) + + self.sentences = lm_sentences + self.score = lm_sen_scores + self.pos_score = lm_sen_pos_scores + self.lm_dict = lm_dict + self.no_bpe_sentences = lm_no_bpe_sentences + self.bpe_tokens = lm_bpe_tokens + + +def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None): + """parse output of eval_lm""" + with open(input_file, 'r') as f: + text = f.readlines() + text = text[7:] + cleaned_text = text[:-2] + + sentences = {} + sen_scores = {} + sen_pos_scores = {} + no_bpe_sentences = {} + num_bpe_tokens_dict = {} + for _i, line in enumerate(cleaned_text): + tokens = line.split() + if tokens[0].isdigit(): + line_id = int(tokens[0]) + scores = [float(x[1:-1]) for x in tokens[2::2]] + sentences[line_id] = " ".join(tokens[1::2][:-1])+"\n" + if bpe_symbol is not None: + # exclude symbol to match output from generate.py + bpe_sen = " ".join(tokens[1::2][:-1])+"\n" + no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol) + no_bpe_sentences[line_id] = no_bpe_sen + + if prefix_len is not None: + num_bpe_tokens = get_num_bpe_tokens_from_len(bpe_sen, bpe_symbol, prefix_len) + sen_scores[line_id] = sum(scores[:num_bpe_tokens]) + num_bpe_tokens_dict[line_id] = num_bpe_tokens + elif target_prefix_frac is not None: + num_words, shortened, target_prefix_len = calc_length_from_frac(bpe_sen, target_prefix_frac, + bpe_symbol) + sen_scores[line_id] = sum(scores[:target_prefix_len]) + num_bpe_tokens_dict[line_id] = target_prefix_len + else: + sen_scores[line_id] = sum(scores) + num_bpe_tokens_dict[line_id] = len(scores) + + sen_pos_scores[line_id] = scores + + return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict + + +def get_directories(data_dir_name, num_rescore, gen_subset, + fw_name, shard_id, num_shards, + sampling=False, prefix_len=None, + target_prefix_frac=None, source_prefix_frac=None): + nbest_file_id = "nbest_" + str(num_rescore) + \ + "_subset_" + gen_subset + \ + "_fw_name_" + fw_name + \ + "_shard_" + str(shard_id) + \ + "_of_" + str(num_shards) + + if sampling: + nbest_file_id += "_sampling" + + # the directory containing all information for this nbest list + pre_gen = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+data_dir_name+"/"+nbest_file_id + # the directory to store the preprocessed nbest list, for left to right rescoring + left_to_right_preprocessed_dir = pre_gen+"/left_to_right_preprocessed" + if source_prefix_frac is not None: + left_to_right_preprocessed_dir = left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac) + # the directory to store the preprocessed nbest list, for right to left rescoring + right_to_left_preprocessed_dir = pre_gen+"/right_to_left_preprocessed" + # the directory to store the preprocessed nbest 
+def get_directories(data_dir_name, num_rescore, gen_subset,
+                    fw_name, shard_id, num_shards,
+                    sampling=False, prefix_len=None,
+                    target_prefix_frac=None, source_prefix_frac=None):
+    nbest_file_id = "nbest_" + str(num_rescore) + \
+                    "_subset_" + gen_subset + \
+                    "_fw_name_" + fw_name + \
+                    "_shard_" + str(shard_id) + \
+                    "_of_" + str(num_shards)
+
+    if sampling:
+        nbest_file_id += "_sampling"
+
+    # the directory containing all information for this nbest list
+    pre_gen = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+data_dir_name+"/"+nbest_file_id
+    # the directory to store the preprocessed nbest list, for left to right rescoring
+    left_to_right_preprocessed_dir = pre_gen+"/left_to_right_preprocessed"
+    if source_prefix_frac is not None:
+        left_to_right_preprocessed_dir = left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
+    # the directory to store the preprocessed nbest list, for right to left rescoring
+    right_to_left_preprocessed_dir = pre_gen+"/right_to_left_preprocessed"
+    # the directory to store the preprocessed nbest list, for backwards rescoring
+    backwards_preprocessed_dir = pre_gen+"/backwards"
+    if target_prefix_frac is not None:
+        backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_frac"+str(target_prefix_frac)
+    elif prefix_len is not None:
+        backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_"+str(prefix_len)
+
+    # the directory to store the preprocessed nbest list, for rescoring with P(T)
+    lm_preprocessed_dir = pre_gen+"/lm_preprocessed"
+
+    return pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
+        backwards_preprocessed_dir, lm_preprocessed_dir
+
+
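+# As an example, with the README settings (50 hypotheses, data_dir_name
+# "wmt17", forward model name "fw_model_ex") and a hypothetical gen_subset of
+# "test" on a single shard, get_directories lays out:
+#
+#   examples/noisychannel/rerank_data/wmt17/
+#       nbest_50_subset_test_fw_name_fw_model_ex_shard_0_of_1/
+#           left_to_right_preprocessed/   # nbest list for left-to-right rescoring
+#           right_to_left_preprocessed/   # nbest list for right-to-left rescoring
+#           backwards/                    # nbest list for channel model rescoring
+#           lm_preprocessed/              # nbest list for rescoring with P(T)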
+def lm_scoring(preprocess_directory, bpe_status, gen_output, pre_gen,
+               cur_lm_dict, cur_lm_name, cur_language_model, cur_lm_bpe_code,
+               batch_size, lm_score_file, target_lang, source_lang, prefix_len=None):
+    if prefix_len is not None:
+        assert bpe_status == "different", "bpe status must be different to use prefix len"
+    if bpe_status == "no bpe":
+        # run lm on output without bpe, using the source/target language suffixes
+        write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
+                          gen_output.no_bpe_target, pre_gen+"/rescore_data_no_bpe."+source_lang,
+                          pre_gen+"/rescore_data_no_bpe."+target_lang, pre_gen+"/reference_file_no_bpe")
+
+        preprocess_lm_param = ["--only-source",
+                               "--trainpref", pre_gen+"/rescore_data_no_bpe."+target_lang,
+                               "--srcdict", cur_lm_dict,
+                               "--destdir", preprocess_directory]
+        preprocess_parser = options.get_preprocessing_parser()
+        input_args = preprocess_parser.parse_args(preprocess_lm_param)
+        preprocess.main(input_args)
+
+        eval_lm_param = [preprocess_directory,
+                         "--path", cur_language_model,
+                         "--output-word-probs",
+                         "--batch-size", str(batch_size),
+                         "--max-tokens", "1024",
+                         "--sample-break-mode", "eos",
+                         "--gen-subset", "train"]
+
+        eval_lm_parser = options.get_eval_lm_parser()
+        input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
+
+        with open(lm_score_file, 'w') as f:
+            with redirect_stdout(f):
+                eval_lm.main(input_args)
+
+    elif bpe_status == "shared":
+        preprocess_lm_param = ["--only-source",
+                               "--trainpref", pre_gen+"/rescore_data."+target_lang,
+                               "--srcdict", cur_lm_dict,
+                               "--destdir", preprocess_directory]
+        preprocess_parser = options.get_preprocessing_parser()
+        input_args = preprocess_parser.parse_args(preprocess_lm_param)
+        preprocess.main(input_args)
+
+        eval_lm_param = [preprocess_directory,
+                         "--path", cur_language_model,
+                         "--output-word-probs",
+                         "--batch-size", str(batch_size),
+                         "--sample-break-mode", "eos",
+                         "--gen-subset", "train"]
+
+        eval_lm_parser = options.get_eval_lm_parser()
+        input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
+
+        with open(lm_score_file, 'w') as f:
+            with redirect_stdout(f):
+                eval_lm.main(input_args)
+
+    elif bpe_status == "different":
+        rescore_file = pre_gen+"/rescore_data_no_bpe"
+        rescore_bpe = pre_gen+"/rescore_data_new_bpe"
+
+        rescore_file += "."
+        rescore_bpe += "."
+
+        write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
+                          gen_output.no_bpe_target, rescore_file+source_lang,
+                          rescore_file+target_lang, pre_gen+"/reference_file_no_bpe",
+                          bpe_symbol=None)
+
+        # apply LM bpe to nbest list
+        bpe_src_param = ["-c", cur_lm_bpe_code,
+                         "--input", rescore_file+target_lang,
+                         "--output", rescore_bpe+target_lang]
+        subprocess.call(["python",
+                         os.path.join(os.path.dirname(__file__),
+                                      "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
+                        shell=False)
+        # uncomment to use fastbpe instead of subword-nmt bpe
+        # bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
+        # subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)
+
+        preprocess_dir = preprocess_directory
+
+        preprocess_lm_param = ["--only-source",
+                               "--trainpref", rescore_bpe+target_lang,
+                               "--srcdict", cur_lm_dict,
+                               "--destdir", preprocess_dir]
+        preprocess_parser = options.get_preprocessing_parser()
+        input_args = preprocess_parser.parse_args(preprocess_lm_param)
+        preprocess.main(input_args)
+
+        eval_lm_param = [preprocess_dir,
+                         "--path", cur_language_model,
+                         "--output-word-probs",
+                         "--batch-size", str(batch_size),
+                         "--max-tokens", "1024",
+                         "--sample-break-mode", "eos",
+                         "--gen-subset", "train"]
+
+        eval_lm_parser = options.get_eval_lm_parser()
+        input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
+
+        with open(lm_score_file, 'w') as f:
+            with redirect_stdout(f):
+                eval_lm.main(input_args)
+
+
+def rescore_file_name(nbest_dir, prefix_len, scorer_name, lm_file=False,
+                      target_prefix_frac=None, source_prefix_frac=None, backwards=None):
+    if lm_file:
+        score_file = nbest_dir+"/lm_score_translations_model_"+scorer_name+".txt"
+    else:
+        score_file = nbest_dir+"/"+scorer_name+"_score_translations.txt"
+    if backwards:
+        if prefix_len is not None:
+            score_file += "prefix_len"+str(prefix_len)
+        elif target_prefix_frac is not None:
+            score_file += "target_prefix_frac"+str(target_prefix_frac)
+    else:
+        if source_prefix_frac is not None:
+            score_file += "source_prefix_frac"+str(source_prefix_frac)
+    return score_file
diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py
index c9ba53707..7fedc7755 100644
--- a/fairseq/models/transformer.py
+++ b/fairseq/models/transformer.py
@@ -498,7 +498,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                         del state_dict[k]
 
         version_key = '{}.version'.format(name)
-        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
+        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
             # earlier checkpoints did not normalize after the stack of layers
             self.layer_norm = None
             self.normalize = False