mirror of
https://github.com/rsennrich/subword-nmt.git
synced 2024-11-23 08:29:06 +03:00
1290ace1e1
args.separator was being passed to merges instead of separator.
126 lines
4.6 KiB
Python
Executable File
126 lines
4.6 KiB
Python
Executable File
#!/usr/bin/python
|
|
# -*- coding: utf-8 -*-
|
|
# Author: Rico Sennrich
|
|
|
|
"""Use byte pair encoding (BPE) to learn a variable-length encoding of the vocabulary in a text.
|
|
This script learns BPE jointly on a concatenation of a list of texts (typically the source and target side of a parallel corpus,
|
|
applies the learned operation to each and (optionally) returns the resulting vocabulary of each text.
|
|
The vocabulary can be used in apply_bpe.py to avoid producing symbols that are rare or OOV in a training text.
|
|
|
|
Reference:
|
|
Rico Sennrich, Barry Haddow and Alexandra Birch (2016). Neural Machine Translation of Rare Words with Subword Units.
|
|
Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany.
|
|
"""
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
import sys
|
|
import os
|
|
import codecs
|
|
import argparse
|
|
import tempfile
|
|
from collections import Counter
|
|
|
|
import learn_bpe
|
|
import apply_bpe
|
|
|
|
# hack for python2/3 compatibility
|
|
from io import open
|
|
argparse.open = open
|
|
|
|
def create_parser():
|
|
parser = argparse.ArgumentParser(
|
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
description="learn BPE-based word segmentation")
|
|
|
|
parser.add_argument(
|
|
'--input', '-i', type=argparse.FileType('r'), required=True, nargs = '+',
|
|
metavar='PATH',
|
|
help="Input texts (multiple allowed).")
|
|
parser.add_argument(
|
|
'--output', '-o', type=argparse.FileType('w'), required=True,
|
|
metavar='PATH',
|
|
help="Output file for BPE codes.")
|
|
parser.add_argument(
|
|
'--symbols', '-s', type=int, default=10000,
|
|
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))")
|
|
parser.add_argument(
|
|
'--separator', type=str, default='@@', metavar='STR',
|
|
help="Separator between non-final subword units (default: '%(default)s'))")
|
|
parser.add_argument(
|
|
'--write-vocabulary', type=argparse.FileType('w'), nargs = '+', default=None,
|
|
metavar='PATH', dest='vocab',
|
|
help='Write to these vocabulary files after applying BPE. One per input text. Used for filtering in apply_bpe.py')
|
|
parser.add_argument(
|
|
'--min-frequency', type=int, default=2, metavar='FREQ',
|
|
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))')
|
|
parser.add_argument(
|
|
'--verbose', '-v', action="store_true",
|
|
help="verbose mode.")
|
|
|
|
return parser
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
# python 2/3 compatibility
|
|
if sys.version_info < (3, 0):
|
|
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
|
|
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
|
|
sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
|
|
else:
|
|
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
|
|
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
|
|
sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)
|
|
|
|
parser = create_parser()
|
|
args = parser.parse_args()
|
|
|
|
if args.vocab and len(args.input) != len(args.vocab):
|
|
sys.stderr.write('Error: number of input files and vocabulary files must match\n')
|
|
sys.exit(1)
|
|
|
|
# read/write files as UTF-8
|
|
args.input = [codecs.open(f.name, encoding='UTF-8') for f in args.input]
|
|
args.vocab = [codecs.open(f.name, 'w', encoding='UTF-8') for f in args.vocab]
|
|
|
|
# get combined vocabulary of all input texts
|
|
full_vocab = Counter()
|
|
for f in args.input:
|
|
full_vocab += learn_bpe.get_vocabulary(f)
|
|
f.seek(0)
|
|
|
|
vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]
|
|
|
|
# learn BPE on combined vocabulary
|
|
with codecs.open(args.output.name, 'w', encoding='UTF-8') as output:
|
|
learn_bpe.main(vocab_list, output, args.symbols, args.min_frequency, args.verbose, is_dict=True)
|
|
|
|
with codecs.open(args.output.name, encoding='UTF-8') as codes:
|
|
bpe = apply_bpe.BPE(codes, separator=args.separator)
|
|
|
|
# apply BPE to each training corpus and get vocabulary
|
|
for train_file, vocab_file in zip(args.input, args.vocab):
|
|
|
|
tmp = tempfile.NamedTemporaryFile(delete=False)
|
|
tmp.close()
|
|
|
|
tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')
|
|
|
|
train_file.seek(0)
|
|
for line in train_file:
|
|
tmpout.write(bpe.segment(line).strip())
|
|
tmpout.write('\n')
|
|
|
|
tmpout.close()
|
|
tmpin = codecs.open(tmp.name, encoding='UTF-8')
|
|
|
|
vocab = learn_bpe.get_vocabulary(tmpin)
|
|
tmpin.close()
|
|
os.remove(tmp.name)
|
|
|
|
for key, freq in sorted(vocab.items(), key=lambda x: x[1], reverse=True):
|
|
vocab_file.write("{0} {1}\n".format(key, freq))
|
|
vocab_file.close()
|