Backward reranking public (#667)

Summary:
Implementation of noisy channel model reranking for release with paper
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/667

Reviewed By: michaelauli

Differential Revision: D15901665

Pulled By: nng555

fbshipit-source-id: 2de2c518be8e5828ffad72db3e741b0940623373
This commit is contained in:
Nathan Ng 2019-08-15 10:03:44 -07:00 committed by Facebook Github Bot
parent ac66df47b5
commit 49177c99c4
13 changed files with 1629 additions and 3 deletions

3
.gitignore vendored
View File

@ -116,3 +116,6 @@ fairseq/modules/*_layer/*_backward.cu
# data
data-bin/
# reranking
examples/reranking/rerank_data

View File

@ -146,8 +146,9 @@ def main(parsed_args):
hypos = scorer.generate(models, sample)
gen_timer.stop(sample['ntokens'])
for hypos_i in hypos:
for i, hypos_i in enumerate(hypos):
hypo = hypos_i[0]
sample_id = sample['id'][i]
tokens = hypo['tokens']
tgt_len = tokens.numel()
@ -199,7 +200,8 @@ def main(parsed_args):
is_bpe = False
w = ''
if args.output_word_probs:
print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
print(str(int(sample_id)) + " " +
('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob)))
wps_meter.update(sample['ntokens'])
t.log({'wps': round(wps_meter.avg)})

10
examples/__init__.py Normal file
View File

@ -0,0 +1,10 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
__version__ = '0.7.2'
import examples.noisychannel # noqa

View File

@ -0,0 +1,72 @@
# Simple and Effective Noisy Channel Modeling for Neural Machine Translation (Yee et al., 2019)
This page contains pointers to pre-trained models as well as instructions on how to run the reranking scripts.
## Citation:
```bibtex
@inproceedings{yee2018simple,
title = {Simple and Effective Noisy Channel Modeling for Neural Machine Translation},
author = {Kyra Yee and Yann Dauphin and Michael Auli},
booktitle = {Conference on Empirical Methods in Natural Language Processing},
year = {2019},
}
```
## Pre-trained Models:
Model | Description | Download
---|---|---
`transformer.noisychannel.de-en` | De->En Forward Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2)
`transformer.noisychannel.en-de` | En->De Channel Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2)
`transformer_lm.noisychannel.en` | En Language model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2)
Test Data: [newstest_wmt17](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2)
## Example usage
```
mkdir rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2 | tar xvjf - -C rerank_example
beam=50
num_trials=1000
fw_name=fw_model_ex
bw_name=bw_model_ex
lm_name=lm_ex
data_dir=rerank_example/hyphen-splitting-mixed-case-wmt17test-wmt14bpe
data_dir_name=wmt17
lm=rerank_example/lm/checkpoint_best.pt
lm_bpe_code=rerank_example/lm/bpe32k.code
lm_dict=rerank_example/lm/dict.txt
batch_size=32
bw=rerank_example/backward_en2de.pt
fw=rerank_example/forward_de2en.pt
# reranking with P(T|S) P(S|T) and P(T)
python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight1 weight3 \
--lower-bound 0 0 0 --upper-bound 3 3 3 --data-dir-name $data_dir_name \
--num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
-n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw \
--backwards1 --weight2 1 \
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
--model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name
# reranking with P(T|S) and P(T)
python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight3 \
--lower-bound 0 0 --upper-bound 3 3 --data-dir-name $data_dir_name \
--num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
-n $beam --batch-size $batch_size --score-model1 $fw \
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
--model1-name $fw_name --gen-model-name $fw_name
# to run with a preconfigured set of hyperparameters for the lenpen and model weights, using rerank.py instead.
python examples/noisychannel/rerank.py $data_dir \
--lenpen 0.269 --weight1 1 --weight2 0.929 --weight3 0.831 \
--data-dir-name $data_dir_name --source-lang de --target-lang en --gen-model $fw \
-n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw --backwards1 \
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
--model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name
```

View File

@ -0,0 +1,8 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .rerank_options import *

View File

@ -0,0 +1,283 @@
import rerank_utils
import rerank_generate
import rerank_score_bw
import rerank_score_lm
from fairseq import bleu, options
from fairseq.data import dictionary
from examples.noisychannel import rerank_options
from multiprocessing import Pool
import math
import numpy as np
def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize):
print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
dict = dictionary.Dictionary()
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
ordered_hypos = {}
ordered_targets = {}
for shard_id in range(len(bitext1_lst)):
bitext1 = bitext1_lst[shard_id]
bitext2 = bitext2_lst[shard_id]
gen_output = gen_output_lst[shard_id]
lm_res = lm_res_lst[shard_id]
total = len(bitext1.rescore_source.keys())
source_lst = []
hypo_lst = []
score_lst = []
reference_lst = []
j = 1
best_score = -math.inf
for i in range(total):
# length is measured in terms of words, not bpe tokens, since models may not share the same bpe
target_len = len(bitext1.rescore_hypo[i].split())
if lm_res is not None:
lm_score = lm_res.score[i]
else:
lm_score = 0
if bitext2 is not None:
bitext2_score = bitext2.rescore_score[i]
bitext2_backwards = bitext2.backwards
else:
bitext2_score = None
bitext2_backwards = None
score = rerank_utils.get_score(a, b, c, target_len,
bitext1.rescore_score[i], bitext2_score, lm_score=lm_score,
lenpen=lenpen, src_len=bitext1.source_lengths[i],
tgt_len=bitext1.target_lengths[i], bitext1_backwards=bitext1.backwards,
bitext2_backwards=bitext2_backwards, normalize=normalize)
if score > best_score:
best_score = score
best_hypo = bitext1.rescore_hypo[i]
if j == gen_output.num_hypos[i] or j == args.num_rescore:
j = 1
hypo_lst.append(best_hypo)
score_lst.append(best_score)
source_lst.append(bitext1.rescore_source[i])
reference_lst.append(bitext1.rescore_target[i])
best_score = -math.inf
best_hypo = ""
else:
j += 1
gen_keys = list(sorted(gen_output.no_bpe_target.keys()))
for key in range(len(gen_keys)):
if args.prefix_len is None:
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \
("pred and rescore hypo mismatch: i: " + str(key) + ", " + str(hypo_lst[key]) + str(gen_keys[key]) +
str(gen_output.no_bpe_hypo[key]))
sys_tok = dict.encode_line(hypo_lst[key])
ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
scorer.add(ref_tok, sys_tok)
else:
full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
sys_tok = dict.encode_line(full_hypo)
ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
scorer.add(ref_tok, sys_tok)
# if only one set of hyper parameters is provided, write the predictions to a file
if write_hypos:
# recover the orinal ids from n best list generation
for key in range(len(gen_output.no_bpe_target)):
if args.prefix_len is None:
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \
"pred and rescore hypo mismatch:"+"i:"+str(key)+str(hypo_lst[key]) + str(gen_output.no_bpe_hypo[key])
ordered_hypos[gen_keys[key]] = hypo_lst[key]
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
else:
full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
ordered_hypos[gen_keys[key]] = full_hypo
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
# write the hypos in the original order from nbest list generation
if args.num_shards == (len(bitext1_lst)):
with open(target_outfile, 'w') as t:
with open(hypo_outfile, 'w') as h:
for key in range(len(ordered_hypos)):
t.write(ordered_targets[key])
h.write(ordered_hypos[key])
res = scorer.result_string(4)
if write_hypos:
print(res)
score = rerank_utils.parse_bleu_scoring(res)
return score
def match_target_hypo(args, target_outfile, hypo_outfile):
"""combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
if len(args.weight1) == 1:
res = score_target_hypo(args, args.weight1[0], args.weight2[0],
args.weight3[0], args.lenpen[0], target_outfile,
hypo_outfile, True, args.normalize)
rerank_scores = [res]
else:
print("launching pool")
with Pool(32) as p:
rerank_scores = p.starmap(score_target_hypo,
[(args, args.weight1[i], args.weight2[i], args.weight3[i],
args.lenpen[i], target_outfile, hypo_outfile,
False, args.normalize) for i in range(len(args.weight1))])
if len(rerank_scores) > 1:
best_index = np.argmax(rerank_scores)
best_score = rerank_scores[best_index]
print("best score", best_score)
print("best lenpen", args.lenpen[best_index])
print("best weight1", args.weight1[best_index])
print("best weight2", args.weight2[best_index])
print("best weight3", args.weight3[best_index])
return args.lenpen[best_index], args.weight1[best_index], \
args.weight2[best_index], args.weight3[best_index], best_score
else:
return args.lenpen[0], args.weight1[0], args.weight2[0], args.weight3[0], rerank_scores[0]
def load_score_files(args):
if args.all_shards:
shard_ids = list(range(args.num_shards))
else:
shard_ids = [args.shard_id]
gen_output_lst = []
bitext1_lst = []
bitext2_lst = []
lm_res1_lst = []
for shard_id in shard_ids:
using_nbest = args.nbest_list is not None
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir = \
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
args.gen_model_name, shard_id, args.num_shards, args.sampling,
args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards1)
if args.score_model2 is not None:
score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards2)
if args.language_model is not None:
lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)
# get gen output
predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
if using_nbest:
print("Using predefined n-best list from interactive.py")
predictions_bpe_file = args.nbest_list
gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
nbest=using_nbest, prefix_len=args.prefix_len,
target_prefix_frac=args.target_prefix_frac)
if rerank1_is_gen:
bitext1 = gen_output
else:
bitext1 = rerank_utils.BitextOutput(score1_file, args.backwards1, args.right_to_left1,
args.remove_bpe, args.prefix_len, args.target_prefix_frac,
args.source_prefix_frac)
if args.score_model2 is not None or args.nbest_list is not None:
if rerank2_is_gen:
bitext2 = gen_output
else:
bitext2 = rerank_utils.BitextOutput(score2_file, args.backwards2, args.right_to_left2,
args.remove_bpe, args.prefix_len, args.target_prefix_frac,
args.source_prefix_frac)
assert bitext2.source_lengths == bitext1.source_lengths, \
"source lengths for rescoring models do not match"
assert bitext2.target_lengths == bitext1.target_lengths, \
"target lengths for rescoring models do not match"
else:
if args.diff_bpe:
assert args.score_model2 is None
bitext2 = gen_output
else:
bitext2 = None
if args.language_model is not None:
lm_res1 = rerank_utils.LMOutput(lm_score_file, args.lm_dict, args.prefix_len,
args.remove_bpe, args.target_prefix_frac)
else:
lm_res1 = None
gen_output_lst.append(gen_output)
bitext1_lst.append(bitext1)
bitext2_lst.append(bitext2)
lm_res1_lst.append(lm_res1)
return gen_output_lst, bitext1_lst, bitext2_lst, lm_res1_lst
def rerank(args):
if type(args.lenpen) is not list:
args.lenpen = [args.lenpen]
if type(args.weight1) is not list:
args.weight1 = [args.weight1]
if type(args.weight2) is not list:
args.weight2 = [args.weight2]
if type(args.weight3) is not list:
args.weight3 = [args.weight3]
if args.all_shards:
shard_ids = list(range(args.num_shards))
else:
shard_ids = [args.shard_id]
for shard_id in shard_ids:
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir = \
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
args.gen_model_name, shard_id, args.num_shards, args.sampling,
args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
rerank_generate.gen_and_reprocess_nbest(args)
rerank_score_bw.score_bw(args)
rerank_score_lm.score_lm(args)
if args.write_hypos is None:
write_targets = pre_gen+"/matched_targets"
write_hypos = pre_gen+"/matched_hypos"
else:
write_targets = args.write_hypos+"_targets" + args.gen_subset
write_hypos = args.write_hypos+"_hypos" + args.gen_subset
if args.all_shards:
write_targets += "_all_shards"
write_hypos += "_all_shards"
best_lenpen, best_weight1, best_weight2, best_weight3, best_score = \
match_target_hypo(args, write_targets, write_hypos)
return best_lenpen, best_weight1, best_weight2, best_weight3, best_score
def cli_main():
parser = rerank_options.get_reranking_parser()
args = options.parse_args_and_arch(parser)
rerank(args)
if __name__ == '__main__':
cli_main()

View File

@ -0,0 +1,246 @@
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import rerank_utils
import os
import subprocess
from examples.noisychannel import rerank_options
from fairseq import options
import generate
import preprocess
from contextlib import redirect_stdout
"""
Generate n-best translations using a trained model.
"""
def gen_and_reprocess_nbest(args):
if args.score_dict_dir is None:
args.score_dict_dir = args.data
if args.prefix_len is not None:
assert args.right_to_left1 is False, "prefix length not compatible with right to left models"
assert args.right_to_left2 is False, "prefix length not compatible with right to left models"
if args.nbest_list is not None:
assert args.score_model2 is None
if args.backwards1:
scorer1_src = args.target_lang
scorer1_tgt = args.source_lang
else:
scorer1_src = args.source_lang
scorer1_tgt = args.target_lang
store_data = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+args.data_dir_name
if not os.path.exists(store_data):
os.makedirs(store_data)
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir = \
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
args.gen_model_name, args.shard_id, args.num_shards,
args.sampling, args.prefix_len, args.target_prefix_frac,
args.source_prefix_frac)
assert not (args.right_to_left1 and args.backwards1), "backwards right to left not supported"
assert not (args.right_to_left2 and args.backwards2), "backwards right to left not supported"
assert not (args.prefix_len is not None and args.target_prefix_frac is not None), \
"target prefix frac and target prefix len incompatible"
# make directory to store generation results
if not os.path.exists(pre_gen):
os.makedirs(pre_gen)
rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
if args.nbest_list is not None:
rerank2_is_gen = True
# make directories to store preprossed nbest list for reranking
if not os.path.exists(left_to_right_preprocessed_dir):
os.makedirs(left_to_right_preprocessed_dir)
if not os.path.exists(right_to_left_preprocessed_dir):
os.makedirs(right_to_left_preprocessed_dir)
if not os.path.exists(lm_preprocessed_dir):
os.makedirs(lm_preprocessed_dir)
if not os.path.exists(backwards_preprocessed_dir):
os.makedirs(backwards_preprocessed_dir)
score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards1)
if args.score_model2 is not None:
score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards2)
predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
using_nbest = args.nbest_list is not None
if using_nbest:
print("Using predefined n-best list from interactive.py")
predictions_bpe_file = args.nbest_list
else:
if not os.path.isfile(predictions_bpe_file):
print("STEP 1: generate predictions using the p(T|S) model with bpe")
print(args.data)
param1 = [args.data,
"--path", args.gen_model,
"--shard-id", str(args.shard_id),
"--num-shards", str(args.num_shards),
"--nbest", str(args.num_rescore),
"--batch-size", str(args.batch_size),
"--beam", str(args.num_rescore),
"--max-sentences", str(args.num_rescore),
"--gen-subset", args.gen_subset,
"--source-lang", args.source_lang,
"--target-lang", args.target_lang]
if args.sampling:
param1 += ["--sampling"]
gen_parser = options.get_generation_parser()
input_args = options.parse_args_and_arch(gen_parser, param1)
print(input_args)
with open(predictions_bpe_file, 'w') as f:
with redirect_stdout(f):
generate.main(input_args)
gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
nbest=using_nbest, prefix_len=args.prefix_len,
target_prefix_frac=args.target_prefix_frac)
if args.diff_bpe:
rerank_utils.write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
gen_output.no_bpe_target, pre_gen+"/source_gen_bpe."+args.source_lang,
pre_gen+"/target_gen_bpe."+args.target_lang,
pre_gen+"/reference_gen_bpe."+args.target_lang)
bitext_bpe = args.rescore_bpe_code
bpe_src_param = ["-c", bitext_bpe,
"--input", pre_gen+"/source_gen_bpe."+args.source_lang,
"--output", pre_gen+"/rescore_data."+args.source_lang]
bpe_tgt_param = ["-c", bitext_bpe,
"--input", pre_gen+"/target_gen_bpe."+args.target_lang,
"--output", pre_gen+"/rescore_data."+args.target_lang]
subprocess.call(["python",
os.path.join(os.path.dirname(__file__),
"subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
shell=False)
subprocess.call(["python",
os.path.join(os.path.dirname(__file__),
"subword-nmt/subword_nmt/apply_bpe.py")] + bpe_tgt_param,
shell=False)
if (not os.path.isfile(score1_file) and not rerank1_is_gen) or \
(args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen):
print("STEP 2: process the output of generate.py so we have clean text files with the translations")
rescore_file = "/rescore_data"
if args.prefix_len is not None:
prefix_len_rescore_file = rescore_file + "prefix"+str(args.prefix_len)
if args.target_prefix_frac is not None:
target_prefix_frac_rescore_file = rescore_file + "target_prefix_frac"+str(args.target_prefix_frac)
if args.source_prefix_frac is not None:
source_prefix_frac_rescore_file = rescore_file + "source_prefix_frac"+str(args.source_prefix_frac)
if not args.right_to_left1 or not args.right_to_left2:
if not args.diff_bpe:
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
pre_gen+rescore_file+"."+args.source_lang,
pre_gen+rescore_file+"."+args.target_lang,
pre_gen+"/reference_file", bpe_symbol=args.remove_bpe)
if args.prefix_len is not None:
bw_rescore_file = prefix_len_rescore_file
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
pre_gen+prefix_len_rescore_file+"."+args.source_lang,
pre_gen+prefix_len_rescore_file+"."+args.target_lang,
pre_gen+"/reference_file", prefix_len=args.prefix_len,
bpe_symbol=args.remove_bpe)
elif args.target_prefix_frac is not None:
bw_rescore_file = target_prefix_frac_rescore_file
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
pre_gen+target_prefix_frac_rescore_file+"."+args.source_lang,
pre_gen+target_prefix_frac_rescore_file+"."+args.target_lang,
pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
target_prefix_frac=args.target_prefix_frac)
else:
bw_rescore_file = rescore_file
if args.source_prefix_frac is not None:
fw_rescore_file = source_prefix_frac_rescore_file
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
pre_gen+source_prefix_frac_rescore_file+"."+args.source_lang,
pre_gen+source_prefix_frac_rescore_file+"."+args.target_lang,
pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
source_prefix_frac=args.source_prefix_frac)
else:
fw_rescore_file = rescore_file
if args.right_to_left1 or args.right_to_left2:
rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
pre_gen+"/right_to_left_rescore_data."+args.source_lang,
pre_gen+"/right_to_left_rescore_data."+args.target_lang,
pre_gen+"/right_to_left_reference_file",
right_to_left=True, bpe_symbol=args.remove_bpe)
print("STEP 3: binarize the translations")
if not args.right_to_left1 or args.score_model2 is not None and not args.right_to_left2 or not rerank1_is_gen:
if args.backwards1 or args.backwards2:
if args.backwards_score_dict_dir is not None:
bw_dict = args.backwards_score_dict_dir
else:
bw_dict = args.score_dict_dir
bw_preprocess_param = ["--source-lang", scorer1_src,
"--target-lang", scorer1_tgt,
"--trainpref", pre_gen+bw_rescore_file,
"--srcdict", bw_dict + "/dict." + scorer1_src + ".txt",
"--tgtdict", bw_dict + "/dict." + scorer1_tgt + ".txt",
"--destdir", backwards_preprocessed_dir]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(bw_preprocess_param)
preprocess.main(input_args)
preprocess_param = ["--source-lang", scorer1_src,
"--target-lang", scorer1_tgt,
"--trainpref", pre_gen+fw_rescore_file,
"--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
"--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
"--destdir", left_to_right_preprocessed_dir]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_param)
preprocess.main(input_args)
if args.right_to_left1 or args.right_to_left2:
preprocess_param = ["--source-lang", scorer1_src,
"--target-lang", scorer1_tgt,
"--trainpref", pre_gen+"/right_to_left_rescore_data",
"--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
"--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
"--destdir", right_to_left_preprocessed_dir]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_param)
preprocess.main(input_args)
return gen_output
def cli_main():
parser = rerank_options.get_reranking_parser()
args = options.parse_args_and_arch(parser)
gen_and_reprocess_nbest(args)
if __name__ == '__main__':
cli_main()

View File

@ -0,0 +1,128 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import options
def get_reranking_parser(default_task='translation'):
parser = options.get_parser('Generation and reranking', default_task)
add_reranking_args(parser)
return parser
def get_tuning_parser(default_task='translation'):
parser = options.get_parser('Reranking tuning', default_task)
add_reranking_args(parser)
add_tuning_args(parser)
return parser
def add_reranking_args(parser):
group = parser.add_argument_group("Reranking")
# fmt: off
group.add_argument('--score-model1', '-s1', type=str, metavar='FILE', required=True,
help='path to first model or ensemble of models for rescoring')
group.add_argument('--score-model2', '-s2', type=str, metavar='FILE', required=False,
help='path to second model or ensemble of models for rescoring')
group.add_argument('--num-rescore', '-n', type=int, metavar='N', default=10,
help='the number of candidate hypothesis to rescore')
group.add_argument('-bz', '--batch-size', type=int, metavar='N', default=128,
help='batch size for generating the nbest list')
group.add_argument('--gen-subset', default='test', metavar='SET', choices=['test', 'train', 'valid'],
help='data subset to generate (train, valid, test)')
group.add_argument('--gen-model', default=None, metavar='FILE',
help='the model to generate translations')
group.add_argument('-b1', '--backwards1', action='store_true',
help='whether or not the first model group is backwards')
group.add_argument('-b2', '--backwards2', action='store_true',
help='whether or not the second model group is backwards')
group.add_argument('-a', '--weight1', default=1, nargs='+', type=float,
help='the weight(s) of the first model')
group.add_argument('-b', '--weight2', default=1, nargs='+', type=float,
help='the weight(s) of the second model, or the gen model if using nbest from interactive.py')
group.add_argument('-c', '--weight3', default=1, nargs='+', type=float,
help='the weight(s) of the third model')
# lm arguments
group.add_argument('-lm', '--language-model', default=None, metavar='FILE',
help='language model for target language to rescore translations')
group.add_argument('--lm-dict', default=None, metavar='FILE',
help='the dict of the language model for the target language')
group.add_argument('--lm-name', default=None,
help='the name of the language model for the target language')
group.add_argument('--lm-bpe-code', default=None, metavar='FILE',
help='the bpe code for the language model for the target language')
group.add_argument('--data-dir-name', default=None,
help='name of data directory')
group.add_argument('--lenpen', default=1, nargs='+', type=float,
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
group.add_argument('--score-dict-dir', default=None,
help='the directory with dictionaries for the scoring models')
group.add_argument('--right-to-left1', action='store_true',
help='whether the first model group is a right to left model')
group.add_argument('--right-to-left2', action='store_true',
help='whether the second model group is a right to left model')
group.add_argument('--remove-bpe', default='@@ ',
help='the bpe symbol, used for the bitext and LM')
group.add_argument('--prefix-len', default=None, type=int,
help='the length of the target prefix to use in rescoring (in terms of words wo bpe)')
group.add_argument('--sampling', action='store_true',
help='use sampling instead of beam search for generating n best list')
group.add_argument('--diff-bpe', action='store_true',
help='bpe for rescoring and nbest list not the same')
group.add_argument('--rescore-bpe-code', default=None,
help='bpe code for rescoring models')
group.add_argument('--nbest-list', default=None,
help='use predefined nbest list in interactive.py format')
group.add_argument('--write-hypos', default=None,
help='filename prefix to write hypos to')
group.add_argument('--ref-translation', default=None,
help='reference translation to use with nbest list from interactive.py')
group.add_argument('--backwards-score-dict-dir', default=None,
help='the directory with dictionaries for the backwards model,'
'if None then it is assumed the fw and backwards models share dictionaries')
# extra scaling args
group.add_argument('--gen-model-name', default=None,
help='the name of the models that generated the nbest list')
group.add_argument('--model1-name', default=None,
help='the name of the set for model1 group ')
group.add_argument('--model2-name', default=None,
help='the name of the set for model2 group')
group.add_argument('--shard-id', default=0, type=int,
help='the id of the shard to generate')
group.add_argument('--num-shards', default=1, type=int,
help='the number of shards to generate across')
group.add_argument('--all-shards', action='store_true',
help='use all shards')
group.add_argument('--target-prefix-frac', default=None, type=float,
help='the fraction of the target prefix to use in rescoring (in terms of words wo bpe)')
group.add_argument('--source-prefix-frac', default=None, type=float,
help='the fraction of the source prefix to use in rescoring (in terms of words wo bpe)')
group.add_argument('--normalize', action='store_true',
help='whether to normalize by src and target len')
return group
def add_tuning_args(parser):
group = parser.add_argument_group("Tuning")
group.add_argument('--lower-bound', default=[-0.7], nargs='+', type=float,
help='lower bound of search space')
group.add_argument('--upper-bound', default=[3], nargs='+', type=float,
help='upper bound of search space')
group.add_argument('--tune-param', default=['lenpen'], nargs='+',
choices=['lenpen', 'weight1', 'weight2', 'weight3'],
help='the parameter(s) to tune')
group.add_argument('--tune-subset', default='valid', choices=['valid', 'test', 'train'],
help='the subset to tune on ')
group.add_argument('--num-trials', default=1000, type=int,
help='number of trials to do for random search')
group.add_argument('--share-weights', action='store_true',
help='share weight2 and weight 3')
return group

View File

@ -0,0 +1,95 @@
import rerank_utils
import os
from fairseq import options
from examples.noisychannel import rerank_options
from contextlib import redirect_stdout
import generate
def score_bw(args):
if args.backwards1:
scorer1_src = args.target_lang
scorer1_tgt = args.source_lang
else:
scorer1_src = args.source_lang
scorer1_tgt = args.target_lang
if args.score_model2 is not None:
if args.backwards2:
scorer2_src = args.target_lang
scorer2_tgt = args.source_lang
else:
scorer2_src = args.source_lang
scorer2_tgt = args.target_lang
rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir = \
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
args.gen_model_name, args.shard_id, args.num_shards,
args.sampling, args.prefix_len, args.target_prefix_frac,
args.source_prefix_frac)
score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards1)
if args.score_model2 is not None:
score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
target_prefix_frac=args.target_prefix_frac,
source_prefix_frac=args.source_prefix_frac,
backwards=args.backwards2)
if args.right_to_left1:
rerank_data1 = right_to_left_preprocessed_dir
elif args.backwards1:
rerank_data1 = backwards_preprocessed_dir
else:
rerank_data1 = left_to_right_preprocessed_dir
gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
if not rerank1_is_gen and not os.path.isfile(score1_file):
print("STEP 4: score the translations for model 1")
model_param1 = ["--path", args.score_model1, "--source-lang", scorer1_src, "--target-lang", scorer1_tgt]
gen_model1_param = [rerank_data1] + gen_param + model_param1
gen_parser = options.get_generation_parser()
input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)
with open(score1_file, 'w') as f:
with redirect_stdout(f):
generate.main(input_args)
if args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen:
print("STEP 4: score the translations for model 2")
if args.right_to_left2:
rerank_data2 = right_to_left_preprocessed_dir
elif args.backwards2:
rerank_data2 = backwards_preprocessed_dir
else:
rerank_data2 = left_to_right_preprocessed_dir
model_param2 = ["--path", args.score_model2, "--source-lang", scorer2_src, "--target-lang", scorer2_tgt]
gen_model2_param = [rerank_data2] + gen_param + model_param2
gen_parser = options.get_generation_parser()
input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)
with open(score2_file, 'w') as f:
with redirect_stdout(f):
generate.main(input_args)
def cli_main():
parser = rerank_options.get_reranking_parser()
args = options.parse_args_and_arch(parser)
score_bw(args)
if __name__ == '__main__':
cli_main()

View File

@ -0,0 +1,48 @@
import rerank_utils
import os
from fairseq import options
from examples.noisychannel import rerank_options
def score_lm(args):
using_nbest = args.nbest_list is not None
pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir = \
rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
args.gen_model_name, args.shard_id, args.num_shards,
args.sampling, args.prefix_len, args.target_prefix_frac,
args.source_prefix_frac)
predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"
if using_nbest:
print("Using predefined n-best list from interactive.py")
predictions_bpe_file = args.nbest_list
gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest)
if args.language_model is not None:
lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)
if args.language_model is not None and not os.path.isfile(lm_score_file):
print("STEP 4.5: language modeling for P(T)")
if args.lm_bpe_code is None:
bpe_status = "no bpe"
elif args.lm_bpe_code == "shared":
bpe_status = "shared"
else:
bpe_status = "different"
rerank_utils.lm_scoring(lm_preprocessed_dir, bpe_status, gen_output, pre_gen,
args.lm_dict, args.lm_name, args.language_model,
args.lm_bpe_code, 128, lm_score_file, args.target_lang,
args.source_lang, prefix_len=args.prefix_len)
def cli_main():
parser = rerank_options.get_reranking_parser()
args = options.parse_args_and_arch(parser)
score_lm(args)
if __name__ == '__main__':
cli_main()

View File

@ -0,0 +1,85 @@
import rerank
import argparse
import numpy as np
import random
from examples.noisychannel import rerank_options
from fairseq import options
def random_search(args):
param_values = []
tuneable_parameters = ['lenpen', 'weight1', 'weight2', 'weight3']
initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3]
for i, elem in enumerate(initial_params):
if type(elem) is not list:
initial_params[i] = [elem]
else:
initial_params[i] = elem
tune_parameters = args.tune_param.copy()
for i in range(len(args.tune_param)):
assert args.upper_bound[i] >= args.lower_bound[i]
index = tuneable_parameters.index(args.tune_param[i])
del tuneable_parameters[index]
del initial_params[index]
tune_parameters += tuneable_parameters
param_values += initial_params
random.seed(args.seed)
random_params = np.array([[random.uniform(args.lower_bound[i], args.upper_bound[i])
for i in range(len(args.tune_param))]
for k in range(args.num_trials)])
set_params = np.array([[initial_params[i][0]
for i in range(len(tuneable_parameters))]
for k in range(args.num_trials)])
random_params = np.concatenate((random_params, set_params), 1)
rerank_args = vars(args).copy()
if args.nbest_list:
rerank_args['gen_subset'] = 'test'
else:
rerank_args['gen_subset'] = args.tune_subset
for k in range(len(tune_parameters)):
rerank_args[tune_parameters[k]] = list(random_params[:, k])
if args.share_weights:
k = tune_parameters.index('weight2')
rerank_args['weight3'] = list(random_params[:, k])
rerank_args = argparse.Namespace(**rerank_args)
best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(rerank_args)
rerank_args = vars(args).copy()
rerank_args['lenpen'] = [best_lenpen]
rerank_args['weight1'] = [best_weight1]
rerank_args['weight2'] = [best_weight2]
rerank_args['weight3'] = [best_weight3]
# write the hypothesis from the valid set from the best trial
if args.gen_subset != "valid":
rerank_args['gen_subset'] = "valid"
rerank_args = argparse.Namespace(**rerank_args)
rerank.rerank(rerank_args)
# test with the best hyperparameters on gen subset
rerank_args = vars(args).copy()
rerank_args['gen_subset'] = args.gen_subset
rerank_args['lenpen'] = [best_lenpen]
rerank_args['weight1'] = [best_weight1]
rerank_args['weight2'] = [best_weight2]
rerank_args['weight3'] = [best_weight3]
rerank_args = argparse.Namespace(**rerank_args)
rerank.rerank(rerank_args)
def cli_main():
parser = rerank_options.get_tuning_parser()
args = options.parse_args_and_arch(parser)
random_search(args)
if __name__ == '__main__':
cli_main()

View File

@ -0,0 +1,646 @@
import subprocess
import os
import re
from fairseq import options
import eval_lm
import preprocess
from contextlib import redirect_stdout
import math
def reprocess(fle):
# takes in a file of generate.py translation generate_output
# returns a source dict and hypothesis dict, where keys are the ID num (as a string)
# and values and the corresponding source and translation. There may be several translations
# per source, so the values for hypothesis_dict are lists.
# parses output of generate.py
with open(fle, 'r') as f:
txt = f.read()
"""reprocess generate.py output"""
p = re.compile(r"[STHP][-]\d+\s*")
hp = re.compile(r"(\s*[-]?\d+[.]?\d+\s*)|(\s*(-inf)\s*)")
source_dict = {}
hypothesis_dict = {}
score_dict = {}
target_dict = {}
pos_score_dict = {}
lines = txt.split("\n")
for line in lines:
line += "\n"
prefix = re.search(p, line)
if prefix is not None:
assert len(prefix.group()) > 2, "prefix id not found"
_, j = prefix.span()
id_num = prefix.group()[2:]
id_num = int(id_num)
line_type = prefix.group()[0]
if line_type == "H":
h_txt = line[j:]
hypo = re.search(hp, h_txt)
assert hypo is not None, ("regular expression failed to find the hypothesis scoring")
_, i = hypo.span()
score = hypo.group()
if id_num in hypothesis_dict:
hypothesis_dict[id_num].append(h_txt[i:])
score_dict[id_num].append(float(score))
else:
hypothesis_dict[id_num] = [h_txt[i:]]
score_dict[id_num] = [float(score)]
elif line_type == "S":
source_dict[id_num] = (line[j:])
elif line_type == "T":
target_dict[id_num] = (line[j:])
elif line_type == "P":
pos_scores = (line[j:]).split()
pos_scores = [float(x) for x in pos_scores]
if id_num in pos_score_dict:
pos_score_dict[id_num].append(pos_scores)
else:
pos_score_dict[id_num] = [pos_scores]
return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
def reprocess_nbest(fle):
"""reprocess interactive.py output"""
with open(fle, 'r') as f:
txt = f.read()
source_dict = {}
hypothesis_dict = {}
score_dict = {}
target_dict = {}
pos_score_dict = {}
lines = txt.split("\n")
hp = re.compile(r'[-]?\d+[.]?\d+')
j = -1
for _i, line in enumerate(lines):
line += "\n"
line_type = line[0]
if line_type == "H":
hypo = re.search(hp, line)
_, start_index = hypo.span()
score = hypo.group()
if j in score_dict:
score_dict[j].append(float(score))
hypothesis_dict[j].append(line[start_index:].strip("\t"))
else:
score_dict[j] = [float(score)]
hypothesis_dict[j] = [line[start_index:].strip("\t")]
elif line_type == "O":
j += 1
source_dict[j] = line[2:]
# we don't have the targets for interactive.py
target_dict[j] = "filler"
elif line_type == "P":
pos_scores = [float(pos_score) for pos_score in line.split()[1:]]
if j in pos_score_dict:
pos_score_dict[j].append(pos_scores)
else:
pos_score_dict[j] = [pos_scores]
assert source_dict.keys() == hypothesis_dict.keys()
assert source_dict.keys() == pos_score_dict.keys()
assert source_dict.keys() == score_dict.keys()
return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
def write_reprocessed(sources, hypos, targets, source_outfile,
hypo_outfile, target_outfile, right_to_left=False,
prefix_len=None, bpe_symbol=None,
target_prefix_frac=None, source_prefix_frac=None):
"""writes nbest hypothesis for rescoring"""
assert not (prefix_len is not None and target_prefix_frac is not None), \
"in writing reprocessed, only one type of prefix may be used"
assert not (prefix_len is not None and source_prefix_frac is not None), \
"in writing reprocessed, only one type of prefix may be used"
assert not (target_prefix_frac is not None and source_prefix_frac is not None), \
"in writing reprocessed, only one type of prefix may be used"
with open(source_outfile, 'w') as source_file, \
open(hypo_outfile, 'w') as hypo_file, \
open(target_outfile, 'w') as target_file:
assert len(sources) == len(hypos), "sources and hypos list length mismatch"
if right_to_left:
for i in range(len(sources)):
for j in range(len(hypos[i])):
if prefix_len is None:
hypo_file.write(make_right_to_left(hypos[i][j])+"\n")
else:
raise NotImplementedError()
source_file.write(make_right_to_left(sources[i])+"\n")
target_file.write(make_right_to_left(targets[i])+"\n")
else:
for i in sorted(sources.keys()):
for j in range(len(hypos[i])):
if prefix_len is not None:
shortened = get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)+"\n"
hypo_file.write(shortened)
source_file.write(sources[i])
target_file.write(targets[i])
elif target_prefix_frac is not None:
num_words, shortened, num_bpe_tokens = \
calc_length_from_frac(hypos[i][j], target_prefix_frac, bpe_symbol)
shortened += "\n"
hypo_file.write(shortened)
source_file.write(sources[i])
target_file.write(targets[i])
elif source_prefix_frac is not None:
num_words, shortened, num_bpe_tokensn = \
calc_length_from_frac(sources[i], source_prefix_frac, bpe_symbol)
shortened += "\n"
hypo_file.write(hypos[i][j])
source_file.write(shortened)
target_file.write(targets[i])
else:
hypo_file.write(hypos[i][j])
source_file.write(sources[i])
target_file.write(targets[i])
def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol):
# return number of words, (not bpe tokens) that we want
no_bpe_sen = remove_bpe(bpe_sentence, bpe_symbol)
len_sen = len(no_bpe_sen.split())
num_words = math.ceil(len_sen * prefix_frac)
prefix = get_prefix_no_bpe(bpe_sentence, bpe_symbol, num_words)
num_bpe_tokens = len(prefix.split())
return num_words, prefix, num_bpe_tokens
def get_prefix(sentence, prefix_len):
"""assuming no bpe, gets the prefix of the sentence with prefix_len words"""
tokens = sentence.strip("\n").split()
if prefix_len >= len(tokens):
return sentence.strip("\n")
else:
return " ".join(tokens[:prefix_len])
def get_prefix_no_bpe(sentence, bpe_symbol, prefix_len):
if bpe_symbol is None:
return get_prefix(sentence, prefix_len)
else:
return " ".join(get_prefix_from_len(sentence.split(), bpe_symbol, prefix_len))
def get_prefix_from_len(sentence, bpe_symbol, prefix_len):
"""get the prefix of sentence with bpe, with prefix len in terms of words, not bpe tokens"""
bpe_count = sum([bpe_symbol.strip(" ") in t for t in sentence[:prefix_len]])
if bpe_count == 0:
return sentence[:prefix_len]
else:
return sentence[:prefix_len]+get_prefix_from_len(sentence[prefix_len:], bpe_symbol, bpe_count)
def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len):
"""given a prefix length in terms of words, return the number of bpe tokens"""
prefix = get_prefix_no_bpe(sentence, bpe_symbol, prefix_len)
assert len(remove_bpe(prefix, bpe_symbol).split()) <= prefix_len
return len(prefix.split(" "))
def make_right_to_left(line):
tokens = line.split()
tokens.reverse()
new_line = " ".join(tokens)
return new_line
def remove_bpe(line, bpe_symbol):
line = line.replace("\n", '')
line = (line + ' ').replace(bpe_symbol, '').rstrip()
return line+("\n")
def remove_bpe_dict(pred_dict, bpe_symbol):
new_dict = {}
for i in pred_dict:
if type(pred_dict[i]) == list:
new_list = [remove_bpe(elem, bpe_symbol) for elem in pred_dict[i]]
new_dict[i] = new_list
else:
new_dict[i] = remove_bpe(pred_dict[i], bpe_symbol)
return new_dict
def parse_bleu_scoring(line):
p = re.compile(r'(BLEU4 = )\d+[.]\d+')
res = re.search(p, line)
assert res is not None, line
return float(res.group()[8:])
def get_full_from_prefix(hypo_prefix, hypos):
"""given a hypo prefix, recover the first hypo from the list of complete hypos beginning with that prefix"""
for hypo in hypos:
hypo_prefix = hypo_prefix.strip("\n")
len_prefix = len(hypo_prefix)
if hypo[:len_prefix] == hypo_prefix:
return hypo
# no match found
raise Exception()
def get_score(a, b, c, target_len, bitext_score1, bitext_score2=None, lm_score=None,
lenpen=None, src_len=None, tgt_len=None, bitext1_backwards=False,
bitext2_backwards=False, normalize=False):
if bitext1_backwards:
bitext1_norm = src_len
else:
bitext1_norm = tgt_len
if bitext_score2 is not None:
if bitext2_backwards:
bitext2_norm = src_len
else:
bitext2_norm = tgt_len
else:
bitext2_norm = 1
bitext_score2 = 0
if normalize:
score = a*bitext_score1/bitext1_norm + b*bitext_score2/bitext2_norm+c*lm_score/src_len
else:
score = a*bitext_score1 + b*bitext_score2+c*lm_score
if lenpen is not None:
score /= (target_len) ** float(lenpen)
return score
class BitextOutput(object):
def __init__(self, output_file, backwards, right_to_left, bpe_symbol,
prefix_len=None, target_prefix_frac=None, source_prefix_frac=None):
"""process output from rescoring"""
source, hypo, score, target, pos_score = reprocess(output_file)
if backwards:
self.hypo_fracs = source_prefix_frac
else:
self.hypo_fracs = target_prefix_frac
# remove length penalty so we can use raw scores
score, num_bpe_tokens = get_score_from_pos(pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards)
source_lengths = {}
target_lengths = {}
assert hypo.keys() == source.keys(), "key mismatch"
if backwards:
tmp = hypo
hypo = source
source = tmp
for i in source:
# since we are reranking, there should only be one hypo per source sentence
if backwards:
len_src = len(source[i][0].split())
# record length without <eos>
if len_src == num_bpe_tokens[i][0] - 1:
source_lengths[i] = num_bpe_tokens[i][0] - 1
else:
source_lengths[i] = num_bpe_tokens[i][0]
target_lengths[i] = len(hypo[i].split())
source[i] = remove_bpe(source[i][0], bpe_symbol)
target[i] = remove_bpe(target[i], bpe_symbol)
hypo[i] = remove_bpe(hypo[i], bpe_symbol)
score[i] = float(score[i][0])
pos_score[i] = pos_score[i][0]
else:
len_tgt = len(hypo[i][0].split())
# record length without <eos>
if len_tgt == num_bpe_tokens[i][0] - 1:
target_lengths[i] = num_bpe_tokens[i][0] - 1
else:
target_lengths[i] = num_bpe_tokens[i][0]
source_lengths[i] = len(source[i].split())
if right_to_left:
source[i] = remove_bpe(make_right_to_left(source[i]), bpe_symbol)
target[i] = remove_bpe(make_right_to_left(target[i]), bpe_symbol)
hypo[i] = remove_bpe(make_right_to_left(hypo[i][0]), bpe_symbol)
score[i] = float(score[i][0])
pos_score[i] = pos_score[i][0]
else:
assert len(hypo[i]) == 1, "expected only one hypothesis per source sentence"
source[i] = remove_bpe(source[i], bpe_symbol)
target[i] = remove_bpe(target[i], bpe_symbol)
hypo[i] = remove_bpe(hypo[i][0], bpe_symbol)
score[i] = float(score[i][0])
pos_score[i] = pos_score[i][0]
self.rescore_source = source
self.rescore_hypo = hypo
self.rescore_score = score
self.rescore_target = target
self.rescore_pos_score = pos_score
self.backwards = backwards
self.right_to_left = right_to_left
self.target_lengths = target_lengths
self.source_lengths = source_lengths
class BitextOutputFromGen(object):
def __init__(self, predictions_bpe_file, bpe_symbol=None, nbest=False, prefix_len=None, target_prefix_frac=None):
if nbest:
pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess_nbest(predictions_bpe_file)
else:
pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(predictions_bpe_file)
assert len(pred_source) == len(pred_hypo)
assert len(pred_source) == len(pred_score)
assert len(pred_source) == len(pred_target)
assert len(pred_source) == len(pred_pos_score)
# remove length penalty so we can use raw scores
pred_score, num_bpe_tokens = get_score_from_pos(pred_pos_score, prefix_len, pred_hypo,
bpe_symbol, target_prefix_frac, False)
self.source = pred_source
self.target = pred_target
self.score = pred_score
self.pos_score = pred_pos_score
self.hypo = pred_hypo
self.target_lengths = {}
self.source_lengths = {}
self.no_bpe_source = remove_bpe_dict(pred_source.copy(), bpe_symbol)
self.no_bpe_hypo = remove_bpe_dict(pred_hypo.copy(), bpe_symbol)
self.no_bpe_target = remove_bpe_dict(pred_target.copy(), bpe_symbol)
# indexes to match those from the rescoring models
self.rescore_source = {}
self.rescore_target = {}
self.rescore_pos_score = {}
self.rescore_hypo = {}
self.rescore_score = {}
self.num_hypos = {}
self.backwards = False
self.right_to_left = False
index = 0
for i in sorted(pred_source.keys()):
for j in range(len(pred_hypo[i])):
self.target_lengths[index] = len(self.hypo[i][j].split())
self.source_lengths[index] = len(self.source[i].split())
self.rescore_source[index] = self.no_bpe_source[i]
self.rescore_target[index] = self.no_bpe_target[i]
self.rescore_hypo[index] = self.no_bpe_hypo[i][j]
self.rescore_score[index] = float(pred_score[i][j])
self.rescore_pos_score[index] = pred_pos_score[i][j]
self.num_hypos[index] = len(pred_hypo[i])
index += 1
def get_score_from_pos(pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards):
score_dict = {}
num_bpe_tokens_dict = {}
assert prefix_len is None or hypo_frac is None
for key in pos_score_dict:
score_dict[key] = []
num_bpe_tokens_dict[key] = []
for i in range(len(pos_score_dict[key])):
if prefix_len is not None and not backwards:
num_bpe_tokens = get_num_bpe_tokens_from_len(hypo_dict[key][i], bpe_symbol, prefix_len)
score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens]))
num_bpe_tokens_dict[key].append(num_bpe_tokens)
elif hypo_frac is not None:
num_words, shortened, hypo_prefix_len = calc_length_from_frac(hypo_dict[key][i], hypo_frac, bpe_symbol)
score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len]))
num_bpe_tokens_dict[key].append(hypo_prefix_len)
else:
score_dict[key].append(sum(pos_score_dict[key][i]))
num_bpe_tokens_dict[key].append(len(pos_score_dict[key][i]))
return score_dict, num_bpe_tokens_dict
class LMOutput(object):
def __init__(self, lm_score_file, lm_dict=None, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
lm_sentences, lm_sen_scores, lm_sen_pos_scores, lm_no_bpe_sentences, lm_bpe_tokens = \
parse_lm(lm_score_file, prefix_len=prefix_len,
bpe_symbol=bpe_symbol, target_prefix_frac=target_prefix_frac)
self.sentences = lm_sentences
self.score = lm_sen_scores
self.pos_score = lm_sen_pos_scores
self.lm_dict = lm_dict
self.no_bpe_sentences = lm_no_bpe_sentences
self.bpe_tokens = lm_bpe_tokens
def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
"""parse output of eval_lm"""
with open(input_file, 'r') as f:
text = f.readlines()
text = text[7:]
cleaned_text = text[:-2]
sentences = {}
sen_scores = {}
sen_pos_scores = {}
no_bpe_sentences = {}
num_bpe_tokens_dict = {}
for _i, line in enumerate(cleaned_text):
tokens = line.split()
if tokens[0].isdigit():
line_id = int(tokens[0])
scores = [float(x[1:-1]) for x in tokens[2::2]]
sentences[line_id] = " ".join(tokens[1::2][:-1])+"\n"
if bpe_symbol is not None:
# exclude <eos> symbol to match output from generate.py
bpe_sen = " ".join(tokens[1::2][:-1])+"\n"
no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol)
no_bpe_sentences[line_id] = no_bpe_sen
if prefix_len is not None:
num_bpe_tokens = get_num_bpe_tokens_from_len(bpe_sen, bpe_symbol, prefix_len)
sen_scores[line_id] = sum(scores[:num_bpe_tokens])
num_bpe_tokens_dict[line_id] = num_bpe_tokens
elif target_prefix_frac is not None:
num_words, shortened, target_prefix_len = calc_length_from_frac(bpe_sen, target_prefix_frac,
bpe_symbol)
sen_scores[line_id] = sum(scores[:target_prefix_len])
num_bpe_tokens_dict[line_id] = target_prefix_len
else:
sen_scores[line_id] = sum(scores)
num_bpe_tokens_dict[line_id] = len(scores)
sen_pos_scores[line_id] = scores
return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict
def get_directories(data_dir_name, num_rescore, gen_subset,
fw_name, shard_id, num_shards,
sampling=False, prefix_len=None,
target_prefix_frac=None, source_prefix_frac=None):
nbest_file_id = "nbest_" + str(num_rescore) + \
"_subset_" + gen_subset + \
"_fw_name_" + fw_name + \
"_shard_" + str(shard_id) + \
"_of_" + str(num_shards)
if sampling:
nbest_file_id += "_sampling"
# the directory containing all information for this nbest list
pre_gen = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+data_dir_name+"/"+nbest_file_id
# the directory to store the preprocessed nbest list, for left to right rescoring
left_to_right_preprocessed_dir = pre_gen+"/left_to_right_preprocessed"
if source_prefix_frac is not None:
left_to_right_preprocessed_dir = left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
# the directory to store the preprocessed nbest list, for right to left rescoring
right_to_left_preprocessed_dir = pre_gen+"/right_to_left_preprocessed"
# the directory to store the preprocessed nbest list, for backwards rescoring
backwards_preprocessed_dir = pre_gen+"/backwards"
if target_prefix_frac is not None:
backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_frac"+str(target_prefix_frac)
elif prefix_len is not None:
backwards_preprocessed_dir = backwards_preprocessed_dir+"/prefix_"+str(prefix_len)
# the directory to store the preprocessed nbest list, for rescoring with P(T)
lm_preprocessed_dir = pre_gen+"/lm_preprocessed"
return pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
backwards_preprocessed_dir, lm_preprocessed_dir
def lm_scoring(preprocess_directory, bpe_status, gen_output, pre_gen,
cur_lm_dict, cur_lm_name, cur_language_model, cur_lm_bpe_code,
batch_size, lm_score_file, target_lang, source_lang, prefix_len=None):
if prefix_len is not None:
assert bpe_status == "different", "bpe status must be different to use prefix len"
if bpe_status == "no bpe":
# run lm on output without bpe
write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
gen_output.no_bpe_target, pre_gen+"/rescore_data_no_bpe.de",
pre_gen+"/rescore_data_no_bpe.en", pre_gen+"/reference_file_no_bpe")
preprocess_lm_param = ["--only-source",
"--trainpref", pre_gen+"/rescore_data_no_bpe."+target_lang,
"--srcdict", cur_lm_dict,
"--destdir", preprocess_directory]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
eval_lm_param = [preprocess_directory,
"--path", cur_language_model,
"--output-word-probs",
"--batch-size", str(batch_size),
"--max-tokens", "1024",
"--sample-break-mode", "eos",
"--gen-subset", "train"]
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, 'w') as f:
with redirect_stdout(f):
eval_lm.main(input_args)
elif bpe_status == "shared":
preprocess_lm_param = ["--only-source",
"--trainpref", pre_gen+"/rescore_data."+target_lang,
"--srcdict", cur_lm_dict,
"--destdir", preprocess_directory]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
eval_lm_param = [preprocess_directory,
"--path", cur_language_model,
"--output-word-probs",
"--batch-size", str(batch_size),
"--sample-break-mode", "eos",
"--gen-subset", "train"]
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, 'w') as f:
with redirect_stdout(f):
eval_lm.main(input_args)
elif bpe_status == "different":
rescore_file = pre_gen+"/rescore_data_no_bpe"
rescore_bpe = pre_gen+"/rescore_data_new_bpe"
rescore_file += "."
rescore_bpe += "."
write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
gen_output.no_bpe_target, rescore_file+source_lang,
rescore_file+target_lang, pre_gen+"/reference_file_no_bpe",
bpe_symbol=None)
# apply LM bpe to nbest list
bpe_src_param = ["-c", cur_lm_bpe_code,
"--input", rescore_file+target_lang,
"--output", rescore_bpe+target_lang]
subprocess.call(["python",
os.path.join(os.path.dirname(__file__),
"subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
shell=False)
# uncomment to use fastbpe instead of subword-nmt bpe
# bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
# subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)
preprocess_dir = preprocess_directory
preprocess_lm_param = ["--only-source",
"--trainpref", rescore_bpe+target_lang,
"--srcdict", cur_lm_dict,
"--destdir", preprocess_dir]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
eval_lm_param = [preprocess_dir,
"--path", cur_language_model,
"--output-word-probs",
"--batch-size", str(batch_size),
"--max-tokens", "1024",
"--sample-break-mode", "eos",
"--gen-subset", "train"]
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, 'w') as f:
with redirect_stdout(f):
eval_lm.main(input_args)
def rescore_file_name(nbest_dir, prefix_len, scorer_name, lm_file=False,
target_prefix_frac=None, source_prefix_frac=None, backwards=None):
if lm_file:
score_file = nbest_dir+"/lm_score_translations_model_"+scorer_name+".txt"
else:
score_file = nbest_dir+"/"+scorer_name+"_score_translations.txt"
if backwards:
if prefix_len is not None:
score_file += "prefix_len"+str(prefix_len)
elif target_prefix_frac is not None:
score_file += "target_prefix_frac"+str(target_prefix_frac)
else:
if source_prefix_frac is not None:
score_file += "source_prefix_frac"+str(source_prefix_frac)
return score_file

View File

@ -498,7 +498,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
del state_dict[k]
version_key = '{}.version'.format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False