mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-09-19 05:09:20 +03:00
f296824f40
Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/541 Just a combination of a stacked pair D14057943 & D14176011. Made this as a separate diff because there seems to be some issue with porting a stacked change into the GitHub repo. Differential Revision: D14251048 fbshipit-source-id: 0a47f534a69d6ab2ebe035fba40fd51748cccfb8
79 lines
2.4 KiB
Python
79 lines
2.4 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) 2017-present, Facebook, Inc.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the LICENSE file in
|
|
# the root directory of this source tree. An additional grant of patent rights
|
|
# can be found in the PATENTS file in the same directory.
|
|
"""
|
|
BLEU scoring of generated translations against reference translations.
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
|
|
from fairseq import bleu
|
|
from fairseq.data import dictionary
|
|
|
|
|
|
def get_parser():
    """Build the command-line argument parser for BLEU scoring.

    Options:
        -s/--sys       system output file (default ``'-'``, meaning stdin)
        -r/--ref       reference file (required)
        -o/--order     maximum n-gram order, an int (default 4)
        --ignore-case  flag enabling case-insensitive scoring
        --sacrebleu    flag selecting sacrebleu as the scorer

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    description = 'Command-line script for BLEU scoring.'
    arg_parser = argparse.ArgumentParser(description=description)
    add = arg_parser.add_argument
    # fmt: off
    add('-s', '--sys', default='-', help='system output')
    add('-r', '--ref', required=True, help='references')
    add('-o', '--order', type=int, default=4, metavar='N',
        help='consider ngrams up to this order')
    add('--ignore-case', action='store_true', help='case-insensitive scoring')
    add('--sacrebleu', action='store_true', help='score with sacrebleu')
    # fmt: on
    return arg_parser
|
|
|
|
|
|
def main():
    """Compute BLEU of a system output file against a reference file.

    Arguments are parsed with ``get_parser()``. The system output is read
    from ``--sys``, or from stdin when ``--sys`` is ``'-'``. It is scored
    line-by-line against ``--ref``. The score is printed to stdout.

    Two scorers are available:

    * ``--sacrebleu``: ``sacrebleu.corpus_bleu`` on the raw text lines.
    * default: fairseq's ``bleu.Scorer``, fed lines encoded with
      ``Dictionary.encode_line``. It reports n-grams up to ``--order``.

    With ``--ignore-case``, both system and reference lines are lower-cased
    before scoring, in either mode.

    Exits via ``parser.error`` (status 2) if an input file does not exist.

    Raises:
        ValueError: in the fairseq-scorer path, if the system output and
            reference have different numbers of lines.
    """
    import itertools

    parser = get_parser()
    args = parser.parse_args()
    print(args)

    # Validate with parser.error rather than assert: asserts are stripped
    # under ``python -O``, which would skip these checks entirely.
    if args.sys != '-' and not os.path.exists(args.sys):
        parser.error("System output file {} does not exist".format(args.sys))
    if not os.path.exists(args.ref):
        parser.error("Reference file {} does not exist".format(args.ref))

    # Named ``vocab`` (not ``dict``) to avoid shadowing the builtin.
    vocab = dictionary.Dictionary()

    def readlines(fd):
        """Lazily yield lines of *fd*, lower-cased when --ignore-case is set."""
        for line in fd:
            yield line.lower() if args.ignore_case else line

    if args.sacrebleu:
        import sacrebleu

        def score(fdsys):
            with open(args.ref) as fdref:
                # Route both streams through readlines() so --ignore-case is
                # honored here as well. Materialize them as lists because
                # newer sacrebleu versions require sequences, not streams.
                sys_lines = list(readlines(fdsys))
                ref_lines = list(readlines(fdref))
                print(sacrebleu.corpus_bleu(sys_lines, [ref_lines]))
    else:
        def score(fdsys):
            with open(args.ref) as fdref:
                scorer = bleu.Scorer(vocab.pad(), vocab.eos(), vocab.unk())
                # zip_longest (not zip) so a length mismatch is detected
                # instead of silently scoring only the common prefix.
                # readlines() never yields None, so None marks exhaustion.
                pairs = itertools.zip_longest(readlines(fdsys), readlines(fdref))
                for sys_line, ref_line in pairs:
                    if sys_line is None or ref_line is None:
                        raise ValueError(
                            "System output and reference have different "
                            "numbers of lines"
                        )
                    sys_tok = vocab.encode_line(sys_line)
                    ref_tok = vocab.encode_line(ref_line)
                    scorer.add(ref_tok, sys_tok)
                print(scorer.result_string(args.order))

    if args.sys == '-':
        score(sys.stdin)
    else:
        with open(args.sys, 'r') as f:
            score(f)
|
|
|
|
|
|
# Run the BLEU scoring CLI when this file is executed as a script
# (not when it is imported as a module).
if __name__ == '__main__':
    main()
|