byte-level BPE. Still not fully tested, fails some unit tests with glossaries.

This commit is contained in:
Rico Sennrich 2024-07-24 14:12:50 +02:00
parent 93f0c93ccd
commit 0607b5443f
7 changed files with 281 additions and 176 deletions

View File

@ -1,5 +1,8 @@
CHANGELOG
---------
v0.3.9
- byte-level BPE support
- remove support for Python 2
v0.3.8:
- multiprocessing support (get_vocab and apply_bpe)

View File

@ -11,7 +11,7 @@ def test_suite():
setup(
name='subword_nmt',
version='0.3.8',
version='0.3.9',
description='Unsupervised Word Segmentation for Neural Machine Translation and Text Generation',
long_description=(codecs.open("README.md", encoding='utf-8').read() +
"\n\n" + codecs.open("CHANGELOG.md", encoding='utf-8').read()),
@ -25,7 +25,6 @@ setup(
'Topic :: Text Processing',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 3',
],
install_requires=['mock',

View File

@ -24,35 +24,44 @@ import warnings
import random
import tempfile
from multiprocessing import Pool, cpu_count
# hack for python2/3 compatibility
from io import open
argparse.open = open
from contextlib import contextmanager
class BPE(object):
def __init__(self, codes, merges=-1, separator='@@', vocab=None, glossaries=None):
def __init__(self, codes, merges=-1, separator='@@', vocab=None, glossaries=None, is_bytes=False):
codes.seek(0)
offset=1
# check version information
firstline = codes.readline()
if firstline.startswith('#version:'):
self.version = tuple([int(x) for x in re.sub(r'(\.0+)*$','', firstline.split()[-1]).split(".")])
if is_bytes:
firstline = codes.readline()
self.version = (0, 2)
offset += 1
else:
self.version = (0, 1)
codes.seek(0)
firstline = codes.readline()
if firstline.startswith('#version:'):
self.version = tuple([int(x) for x in re.sub(r'(\.0+)*$','', firstline.split()[-1]).split(".")])
offset += 1
else:
self.version = (0, 1)
codes.seek(0)
self.bpe_codes = [tuple(item.strip('\r\n ').split(' ')) for (n, item) in enumerate(codes.read().rstrip('\n').split('\n')) if (n < merges or merges == -1)]
self.strip_chars = b'\r\n ' if is_bytes else '\r\n '
self.newline_char = b'\n' if is_bytes else '\n'
self.split_char = b' ' if is_bytes else ' '
self.bpe_codes = [tuple(item.strip(self.strip_chars).split(self.split_char)) for (n, item) in enumerate(codes.read().rstrip(self.newline_char).split(self.newline_char)) if (n < merges or merges == -1)]
for i, item in enumerate(self.bpe_codes):
if len(item) != 2:
sys.stderr.write('Error: invalid line {0} in BPE codes file: {1}\n'.format(i+offset, ' '.join(item)))
sys.stderr.write('Error: invalid line {0} in BPE codes file: {1}\n'.format(i+offset, self.split_char.join(item)))
sys.stderr.write('The line should exist of exactly two subword units, separated by whitespace\n')
sys.exit(1)
self.is_bytes = is_bytes
# some hacking to deal with duplicates (only consider first instance)
self.bpe_codes = dict([(code,i) for (i,code) in reversed(list(enumerate(self.bpe_codes)))])
@ -62,22 +71,30 @@ class BPE(object):
self.vocab = vocab
self.glossaries = glossaries if glossaries else []
if glossaries:
if is_bytes:
glossaries = [item.encode('utf-8') for item in glossaries]
self.glossaries_regex = re.compile(b'^(' + b'|'.join(glossaries) + b')$')
else:
self.glossaries_regex = re.compile('^({})$'.format('|'.join(glossaries)))
else:
self.glossaries_regex = None
self.glossaries_regex = re.compile('^({})$'.format('|'.join(glossaries))) if glossaries else None
self.glossaries = glossaries if glossaries else []
self.cache = {}
def process_lines(self, filename, outfile, dropout=0, num_workers=1):
if sys.version_info < (3, 0):
print("Parallel mode is only supported in Python3.")
if sys.version_info < (3, 0) :
print("Parallel mode is only supported in Python3")
sys.exit(1)
if num_workers == 1:
_process_lines(self, filename, outfile, dropout, 0, 0)
elif num_workers > 1:
with open(filename, encoding="utf-8") as f:
mode = 'rb' if self.is_bytes else 'r'
with open_file(filename, mode) as f:
size = os.fstat(f.fileno()).st_size
chunk_size = int(size / num_workers)
offsets = [0 for _ in range(num_workers + 1)]
@ -103,7 +120,7 @@ class BPE(object):
pool.close()
pool.join()
for i in range(num_workers):
with open(res_files[i].name, encoding="utf-8") as fi:
with open_file(res_files[i].name, mode) as fi:
for line in fi:
outfile.write(line)
os.remove(res_files[i].name)
@ -113,15 +130,15 @@ class BPE(object):
def process_line(self, line, dropout=0):
"""segment line, dealing with leading and trailing whitespace"""
out = ""
out = b"" if self.is_bytes else ""
leading_whitespace = len(line)-len(line.lstrip('\r\n '))
leading_whitespace = len(line)-len(line.lstrip(self.strip_chars))
if leading_whitespace:
out += line[:leading_whitespace]
out += self.segment(line, dropout)
trailing_whitespace = len(line)-len(line.rstrip('\r\n '))
trailing_whitespace = len(line)-len(line.rstrip(self.strip_chars))
if trailing_whitespace and trailing_whitespace != len(line):
out += line[-trailing_whitespace:]
@ -129,8 +146,8 @@ class BPE(object):
def segment(self, sentence, dropout=0):
"""segment single sentence (whitespace-tokenized string) with BPE encoding"""
segments = self.segment_tokens(sentence.strip('\r\n ').split(' '), dropout)
return ' '.join(segments)
segments = self.segment_tokens(sentence.strip(self.strip_chars).split(self.split_char), dropout)
return self.split_char.join(segments)
def segment_tokens(self, tokens, dropout=0):
"""segment a sequence of tokens with BPE encoding"""
@ -148,6 +165,7 @@ class BPE(object):
self.version,
self.cache,
self.glossaries_regex,
self.is_bytes,
dropout)]
for item in new_word[:-1]:
@ -160,15 +178,19 @@ class BPE(object):
word_segments = [word]
for gloss in self.glossaries:
word_segments = [out_segments for segment in word_segments
for out_segments in isolate_glossary(segment, gloss)]
for out_segments in isolate_glossary(segment, gloss, self.is_bytes)]
return word_segments
def _process_lines(bpe, filename, outfile, dropout, begin, end):
write_mode = 'wb' if bpe.is_bytes else 'w'
read_mode = 'rb' if bpe.is_bytes else 'r'
if isinstance(outfile, str):
fo = open(outfile, "w", encoding="utf-8")
fo = open_file(outfile, write_mode)
else:
fo = outfile
with open(filename, encoding="utf-8") as f:
with open_file(filename, read_mode) as f:
f.seek(begin)
line = f.readline()
while line:
@ -181,6 +203,17 @@ def _process_lines(bpe, filename, outfile, dropout, begin, end):
if isinstance(outfile, str):
fo.close()
@contextmanager
def open_file(filename, mode):
if mode in ('r', 'w'):
f = open(filename, mode, encoding="utf-8")
elif mode in ('rb', 'wb'):
f = open(filename, mode)
try:
yield f
finally:
f.close()
def create_parser(subparsers=None):
if subparsers:
@ -193,11 +226,11 @@ def create_parser(subparsers=None):
description="learn BPE-based word segmentation")
parser.add_argument(
'--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
'--input', '-i', type=argparse.FileType('rb'), default=sys.stdin,
metavar='PATH',
help="Input file (default: standard input).")
parser.add_argument(
'--codes', '-c', type=argparse.FileType('r'), metavar='PATH',
'--codes', '-c', type=argparse.FileType('rb'), metavar='PATH',
required=True,
help="File with BPE codes (created by learn_bpe.py).")
parser.add_argument(
@ -206,14 +239,14 @@ def create_parser(subparsers=None):
help="Use this many BPE operations (<= number of learned symbols)"+
"default: Apply all the learned merge operations")
parser.add_argument(
'--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
'--output', '-o', type=argparse.FileType('wb'), default=sys.stdout,
metavar='PATH',
help="Output file (default: standard output)")
parser.add_argument(
'--separator', '-s', type=str, default='@@', metavar='STR',
'--separator', '-s', type=bytes, default=b'@@', metavar='STR',
help="Separator between non-final subword units (default: '%(default)s'))")
parser.add_argument(
'--vocabulary', type=argparse.FileType('r'), default=None,
'--vocabulary', type=argparse.FileType('rb'), default=None,
metavar="PATH",
help="Vocabulary file (built with get_vocab.py). If provided, this script reverts any merge operations that produce an OOV.")
parser.add_argument(
@ -240,7 +273,7 @@ def create_parser(subparsers=None):
return parser
def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries_regex=None, dropout=0):
def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries_regex=None, is_bytes=False, dropout=0):
"""Encode word based on list of BPE merge operations, which are applied consecutively
"""
@ -252,12 +285,16 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache,
return (orig,)
if len(orig) == 1:
return orig
return (orig,)
if version == (0, 1):
eow = b'</w>' if is_bytes else '</w>'
if is_bytes:
word = list(map(lambda b: bytes([b]), orig[:-1])) + [orig[-1:] + eow]
elif version == (0, 1):
word = list(orig) + ['</w>']
elif version == (0, 2): # more consistent handling of word-final segments
word = list(orig[:-1]) + [orig[-1] + '</w>']
word = list(orig[:-1]) + [orig[-1] + eow]
else:
raise NotImplementedError
@ -277,7 +314,10 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache,
i = 0
new_word = []
bigram = ''.join(bigram)
if is_bytes:
bigram = b''.join(bigram)
else:
bigram = ''.join(bigram)
for j in positions:
# merges are invalid if they start before current position. This can happen if there are overlapping pairs: (x x x -> xx x)
if j < i:
@ -289,9 +329,9 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache,
word = new_word
# don't print end-of-word symbols
if word[-1] == '</w>':
if word[-1] == eow:
word = word[:-1]
elif word[-1].endswith('</w>'):
elif word[-1].endswith(eow):
word[-1] = word[-1][:-4]
word = tuple(word)
@ -367,7 +407,7 @@ def read_vocabulary(vocab_file, threshold):
return vocabulary
def isolate_glossary(word, glossary):
def isolate_glossary(word, glossary, is_bytes=False):
"""
Isolate a glossary present inside a word.
@ -376,14 +416,34 @@ def isolate_glossary(word, glossary):
For example, if 'USA' is the glossary and '1934USABUSA' the word, the return value is:
['1934', 'USA', 'B', 'USA']
"""
if is_bytes:
pattern = b'^'+glossary+b'$'
else:
pattern = '^'+glossary+'$'
strip_chars = b'\r\n ' if is_bytes else '\r\n '
empty_string = b'' if is_bytes else ''
# regex equivalent of (if word == glossary or glossary not in word)
if re.match('^'+glossary+'$', word) or not re.search(glossary, word):
if re.match(pattern, word) or not re.search(glossary, word):
return [word]
else:
segments = re.split(r'({})'.format(glossary), word)
segments, ending = segments[:-1], segments[-1]
if is_bytes:
segments = re.split(rb'(' + glossary + rb')', word)
segments, ending = segments[:-1], segments[-1:]
else:
segments = re.split(r'({})'.format(glossary), word)
segments, ending = segments[:-1], segments[-1]
segments = list(filter(None, segments)) # Remove empty strings in regex group.
return segments + [ending.strip('\r\n ')] if ending != '' else segments
return segments + [ending[0].strip(strip_chars)] if ending != empty_string else segments
# first line of BPE code file indicates if it is byte-level or UTF-8
def get_byte_mode(code_file_name):
firstline = open(code_file_name, mode='rb').readline()
if firstline.endswith(b'byte\n'):
return True
else:
return False
if __name__ == '__main__':
@ -397,9 +457,8 @@ if __name__ == '__main__':
# python 2/3 compatibility
if sys.version_info < (3, 0):
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
print("Python 2 is deprecated. Use Python 3")
sys.exit(1)
else:
sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
@ -411,33 +470,37 @@ if __name__ == '__main__':
if args.num_workers <= 0:
args.num_workers = cpu_count()
# read/write files as UTF-8
# check if codes are bytes or UTF-8
is_bytes = get_byte_mode(args.codes.name)
args.separator = args.separator.decode('UTF-8') if not is_bytes else args.separator
# read/write files as bytes or UTF-8, depending on mode
if is_bytes:
if args.input.name == '<stdin>':
args.input = sys.stdin.buffer
if args.output.name == '<stdout>':
args.output = sys.stdout.buffer
else:
args.codes = codecs.open(args.codes.name, encoding='utf-8')
if args.input.name != '<stdin>':
args.input = codecs.open(args.input.name, encoding='utf-8')
if args.output.name != '<stdout>':
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
if args.vocabulary:
args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')
args.codes = codecs.open(args.codes.name, encoding='utf-8')
if args.input.name != '<stdin>':
args.input = codecs.open(args.input.name, encoding='utf-8')
if args.output.name != '<stdout>':
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
if args.vocabulary:
args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')
if args.vocabulary:
vocabulary = read_vocabulary(args.vocabulary, args.vocabulary_threshold)
else:
vocabulary = None
if sys.version_info < (3, 0):
args.separator = args.separator.decode('UTF-8')
if args.glossaries:
args.glossaries = [g.decode('UTF-8') for g in args.glossaries]
if args.num_workers > 1:
args.num_workers = 1
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
if args.seed is not None:
random.seed(args.seed)
bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries)
bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries, is_bytes)
if args.input.name == '<stdin>' or args.num_workers == 1:
if args.num_workers > 1:

View File

@ -31,10 +31,6 @@ except ImportError:
def tqdm(iterator, *args, **kwargs):
return iterator
# hack for python2/3 compatibility
from io import open
argparse.open = open
def create_parser(subparsers=None):
if subparsers:
@ -47,12 +43,11 @@ def create_parser(subparsers=None):
description="learn BPE-based word segmentation")
parser.add_argument(
'--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
'--input', '-i', type=argparse.FileType('rb'), default=sys.stdin,
metavar='PATH',
help="Input text (default: standard input).")
parser.add_argument(
'--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
'--output', '-o', type=argparse.FileType('wb'), default=sys.stdout,
metavar='PATH',
help="Output file for BPE codes (default: standard output)")
parser.add_argument(
@ -61,6 +56,9 @@ def create_parser(subparsers=None):
parser.add_argument(
'--min-frequency', type=int, default=2, metavar='FREQ',
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
parser.add_argument(
'--byte', '-b', action="store_true",
help="byte-level BPE.")
parser.add_argument('--dict-input', action="store_true",
help="If set, input file is interpreted as a dictionary where each line contains a word-count pair")
parser.add_argument(
@ -75,14 +73,18 @@ def create_parser(subparsers=None):
return parser
def get_vocabulary(fobj, is_dict=False, num_workers=1):
def get_vocabulary(fobj, is_dict=False, is_bytes=False, num_workers=1):
"""Read text and return dictionary that encodes vocabulary
"""
vocab = Counter()
strip_chars = b'\r\n ' if is_bytes else '\r\n '
split_char = b' ' if is_bytes else ' '
if is_dict:
for i, line in enumerate(fobj):
try:
word, count = line.strip('\r\n ').split(' ')
word, count = line.strip(strip_chars).split(split_char)
except:
print('Failed reading vocabulary file at line {0}: {1}'.format(i, line))
sys.exit(1)
@ -91,13 +93,13 @@ def get_vocabulary(fobj, is_dict=False, num_workers=1):
if num_workers > 1:
warnings.warn("In parallel mode, the input cannot be STDIN. Using 1 processor instead.")
for i, line in enumerate(fobj):
for word in line.strip('\r\n ').split(' '):
for word in line.strip(strip_chars).split(split_char):
if word:
vocab[word] += 1
elif num_workers > 1:
if sys.version_info < (3, 0):
print("Parallel mode is only supported in Python3.")
if is_bytes:
print('byte-level BPE not yet ported to parallel mode')
sys.exit(1)
with open(fobj.name, encoding="utf8") as f:
@ -231,13 +233,23 @@ def get_pair_statistics(vocab):
return stats, indices
def replace_pair(pair, vocab, indices):
def replace_pair(pair, vocab, indices, is_bytes=False):
"""Replace all occurrences of a symbol pair ('A', 'B') with a new symbol 'AB'"""
split_char = b' ' if is_bytes else ' '
first, second = pair
pair_str = ''.join(pair)
pair_str = pair_str.replace('\\','\\\\')
if is_bytes:
pair_str = b''.join(pair)
# first = bytes((first,))
# second = bytes((second,))
pair_str = pair_str.replace(b'\\',b'\\\\')
pattern = re.compile(rb'(?<!\S)' + re.escape(first + split_char + second) + rb'(?!\S)')
else:
pair_str = ''.join(pair)
pair_str = pair_str.replace('\\','\\\\')
pattern = re.compile(r'(?<!\S)' + re.escape(first + split_char + second) + r'(?!\S)')
changes = []
pattern = re.compile(r'(?<!\S)' + re.escape(first + ' ' + second) + r'(?!\S)')
if sys.version_info < (3, 0):
iterator = indices[pair].iteritems()
else:
@ -246,9 +258,9 @@ def replace_pair(pair, vocab, indices):
if freq < 1:
continue
word, freq = vocab[j]
new_word = ' '.join(word)
new_word = split_char.join(word)
new_word = pattern.sub(pair_str, new_word)
new_word = tuple(new_word.split(' '))
new_word = tuple(new_word.split(split_char))
vocab[j] = (new_word, freq)
changes.append((j, new_word, word, freq))
@ -271,16 +283,22 @@ def prune_stats(stats, big_stats, threshold):
big_stats[item] = freq
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False, num_workers=1):
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, is_bytes=False, total_symbols=False, num_workers=1):
"""Learn num_symbols BPE operations from vocabulary, and write to outfile.
"""
# version 0.2 changes the handling of the end-of-word token ('</w>');
# version numbering allows bckward compatibility
outfile.write('#version: 0.2\n')
if is_bytes:
outfile.write(b'#version: 0.2 byte\n')
else:
outfile.write('#version: 0.2\n')
vocab = get_vocabulary(infile, is_dict, num_workers)
vocab = dict([(tuple(x[:-1])+(x[-1]+'</w>',) ,y) for (x,y) in vocab.items()])
vocab = get_vocabulary(infile, is_dict, is_bytes, num_workers)
if is_bytes:
vocab = dict([(tuple(map(lambda b: bytes([b]), x[:-1]))+(x[-1:]+b'</w>',) ,y) for (x,y) in vocab.items()])
else:
vocab = dict([(tuple(x[:-1])+(x[-1]+'</w>',) ,y) for (x,y) in vocab.items()])
sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)
stats, indices = get_pair_statistics(sorted_vocab)
@ -319,8 +337,11 @@ def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_d
if verbose:
sys.stderr.write('pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format(i, most_frequent[0], most_frequent[1], stats[most_frequent]))
outfile.write('{0} {1}\n'.format(*most_frequent))
changes = replace_pair(most_frequent, sorted_vocab, indices)
if is_bytes:
outfile.write(most_frequent[0] + b' ' + most_frequent[1] + b'\n')
else:
outfile.write('{0} {1}\n'.format(*most_frequent))
changes = replace_pair(most_frequent, sorted_vocab, indices, is_bytes)
update_pair_statistics(most_frequent, changes, stats, indices)
stats[most_frequent] = 0
if not i % 100:
@ -337,33 +358,34 @@ if __name__ == '__main__':
DeprecationWarning
)
# python 2/3 compatibility
if sys.version_info < (3, 0):
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
else:
parser = create_parser()
args = parser.parse_args()
if not args.byte:
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)
parser = create_parser()
args = parser.parse_args()
if args.num_workers <= 0:
args.num_workers = cpu_count()
if sys.version_info < (3, 0) and args.num_workers > 1:
args.num_workers = 1
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
if sys.version_info < (3, 0):
print("Python 2 is deprecated. Use Python 3")
sys.exit(1)
# read/write files as UTF-8
if args.input.name != '<stdin>':
if args.input.name != '<stdin>' and not args.byte:
args.input = codecs.open(args.input.name, encoding='utf-8')
if args.output.name != '<stdout>':
if args.output.name != '<stdout>' and not args.byte:
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols, num_workers=args.num_workers)
if args.byte:
if args.input.name == '<stdin>':
args.input = sys.stdin.buffer
if args.output.name == '<stdout>':
args.output = sys.stdout.buffer
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, is_bytes=args.byte, total_symbols=args.total_symbols, num_workers=args.num_workers)
# close files
if args.input.name != '<stdin>':

View File

@ -32,10 +32,6 @@ else:
from . import learn_bpe
from . import apply_bpe
# hack for python2/3 compatibility
from io import open
argparse.open = open
def create_parser(subparsers=None):
if subparsers:
@ -48,21 +44,24 @@ def create_parser(subparsers=None):
description="learn BPE-based word segmentation")
parser.add_argument(
'--input', '-i', type=argparse.FileType('r'), required=True, nargs = '+',
'--input', '-i', type=argparse.FileType('rb'), required=True, nargs = '+',
metavar='PATH',
help="Input texts (multiple allowed).")
parser.add_argument(
'--output', '-o', type=argparse.FileType('w'), required=True,
'--output', '-o', type=argparse.FileType('wb'), required=True,
metavar='PATH',
help="Output file for BPE codes.")
parser.add_argument(
'--symbols', '-s', type=int, default=10000,
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")
parser.add_argument(
'--separator', type=str, default='@@', metavar='STR',
'--byte', '-b', action="store_true",
help="byte-level BPE.")
parser.add_argument(
'--separator', type=bytes, default=b'@@', metavar='STR',
help="Separator between non-final subword units (default: '%(default)s')")
parser.add_argument(
'--write-vocabulary', type=argparse.FileType('w'), required=True, nargs = '+', default=None,
'--write-vocabulary', type=argparse.FileType('wb'), required=True, nargs = '+', default=None,
metavar='PATH', dest='vocab',
help='Write to these vocabulary files after applying BPE. One per input text. Used for filtering in apply_bpe.py')
parser.add_argument(
@ -86,45 +85,78 @@ def learn_joint_bpe_and_vocab(args):
sys.stderr.write('Error: number of input files and vocabulary files must match\n')
sys.exit(1)
# read/write files as UTF-8
args.input = [codecs.open(f.name, encoding='UTF-8') for f in args.input]
args.vocab = [codecs.open(f.name, 'w', encoding='UTF-8') for f in args.vocab]
if args.byte:
# read/write files as byte streams
args.input = [codecs.open(f.name, 'rb') for f in args.input]
args.vocab = [codecs.open(f.name, 'wb') for f in args.vocab]
else:
# read/write files as UTF-8
args.input = [codecs.open(f.name, encoding='UTF-8') for f in args.input]
args.vocab = [codecs.open(f.name, 'w', encoding='UTF-8') for f in args.vocab]
args.separator = args.separator.decode('UTF-8') if not args.byte else args.separator
# get combined vocabulary of all input texts
full_vocab = Counter()
for f in args.input:
full_vocab += learn_bpe.get_vocabulary(f, num_workers=args.num_workers)
full_vocab += learn_bpe.get_vocabulary(f, num_workers=args.num_workers, is_bytes=args.byte)
f.seek(0)
vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]
if args.byte:
vocab_list = [key + b' ' + str(freq).encode('UTF-8') for (key, freq) in full_vocab.items()]
else:
vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]
# learn BPE on combined vocabulary
with codecs.open(args.output.name, 'w', encoding='UTF-8') as output:
learn_bpe.learn_bpe(vocab_list, output, args.symbols, args.min_frequency, args.verbose, is_dict=True, total_symbols=args.total_symbols)
with codecs.open(args.output.name, encoding='UTF-8') as codes:
bpe = apply_bpe.BPE(codes, separator=args.separator)
if args.byte:
with open(args.output.name, 'wb') as output:
learn_bpe.learn_bpe(vocab_list, output, args.symbols, args.min_frequency, args.verbose, is_dict=True, is_bytes=args.byte, total_symbols=args.total_symbols)
with open(args.output.name, 'rb') as codes:
bpe = apply_bpe.BPE(codes, separator=args.separator, is_bytes=args.byte)
else:
with codecs.open(args.output.name, 'w', encoding='UTF-8') as output:
learn_bpe.learn_bpe(vocab_list, output, args.symbols, args.min_frequency, args.verbose, is_dict=True, is_bytes=args.byte, total_symbols=args.total_symbols)
with codecs.open(args.output.name, encoding='UTF-8') as codes:
bpe = apply_bpe.BPE(codes, separator=args.separator, is_bytes=args.byte)
# apply BPE to each training corpus and get vocabulary
for train_file, vocab_file in zip(args.input, args.vocab):
# read/write files as UTF-8
if not args.byte:
train_file = codecs.open(train_file.name, encoding='utf-8')
vocab_file = codecs.open(vocab_file.name, 'w', encoding='utf-8')
tmp = tempfile.NamedTemporaryFile(delete=False)
tmp.close()
tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')
if args.byte:
tmpout = open(tmp.name, 'wb')
else:
tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')
train_file.seek(0)
bpe.process_lines(train_file.name, tmpout, num_workers=args.num_workers)
tmpout.close()
tmpin = codecs.open(tmp.name, encoding='UTF-8')
vocab = learn_bpe.get_vocabulary(tmpin, num_workers=args.num_workers)
if args.byte:
tmpin = open(tmp.name, 'rb')
else:
tmpin = codecs.open(tmp.name, encoding='UTF-8')
vocab = learn_bpe.get_vocabulary(tmpin, num_workers=args.num_workers, is_bytes=args.byte)
tmpin.close()
os.remove(tmp.name)
for key, freq in sorted(vocab.items(), key=lambda x: x[1], reverse=True):
vocab_file.write("{0} {1}\n".format(key, freq))
if args.byte:
vocab_file.write(key + b" " + str(freq).encode('utf-8') + b"\n")
else:
vocab_file.write("{0} {1}\n".format(key, freq))
train_file.close()
vocab_file.close()
@ -141,13 +173,12 @@ if __name__ == '__main__':
# python 2/3 compatibility
if sys.version_info < (3, 0):
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
else:
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)
print("Python 2 is deprecated. Use Python 3")
sys.exit(1)
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)
parser = create_parser()
args = parser.parse_args()
@ -155,12 +186,6 @@ if __name__ == '__main__':
if args.num_workers <= 0:
args.num_workers = cpu_count()
if sys.version_info < (3, 0):
args.separator = args.separator.decode('UTF-8')
if args.num_workers > 1:
args.num_workers = 1
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
assert(len(args.input) == len(args.vocab))
learn_joint_bpe_and_vocab(args)

View File

@ -7,7 +7,7 @@ import codecs
import argparse
from .learn_bpe import learn_bpe
from .apply_bpe import BPE, read_vocabulary
from .apply_bpe import BPE, read_vocabulary, get_byte_mode
from .get_vocab import get_vocab
from .learn_joint_bpe_and_vocab import learn_joint_bpe_and_vocab
@ -16,9 +16,6 @@ from .apply_bpe import create_parser as create_apply_bpe_parser
from .get_vocab import create_parser as create_get_vocab_parser
from .learn_joint_bpe_and_vocab import create_parser as create_learn_joint_bpe_and_vocab_parser
# hack for python2/3 compatibility
argparse.open = io.open
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter,
@ -39,35 +36,44 @@ learn-joint-bpe-and-vocab: executes recommended workflow for joint BPE.""")
args = parser.parse_args()
if args.command == 'learn-bpe':
# read/write files as UTF-8
if args.input.name != '<stdin>':
args.input = codecs.open(args.input.name, encoding='utf-8')
if args.output.name != '<stdout>':
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
if args.byte:
if args.input.name == '<stdin>':
args.input = sys.stdin.buffer
if args.output.name == '<stdout>':
args.output = sys.stdout.buffer
else:
# read/write files as UTF-8
if args.input.name != '<stdin>':
args.input = codecs.open(args.input.name, encoding='utf-8')
if args.output.name != '<stdout>':
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose,
is_dict=args.dict_input, total_symbols=args.total_symbols)
is_dict=args.dict_input, is_bytes=args.byte, total_symbols=args.total_symbols)
elif args.command == 'apply-bpe':
# read/write files as UTF-8
args.codes = codecs.open(args.codes.name, encoding='utf-8')
if args.input.name != '<stdin>':
args.input = codecs.open(args.input.name, encoding='utf-8')
if args.output.name != '<stdout>':
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
if args.vocabulary:
args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')
is_bytes = get_byte_mode(args.codes.name)
if is_bytes:
if args.input.name == '<stdin>':
args.input = sys.stdin.buffer
if args.output.name == '<stdout>':
args.output = sys.stdout.buffer
else:
# read/write files as UTF-8
args.codes = codecs.open(args.codes.name, encoding='utf-8')
if args.input.name != '<stdin>':
args.input = codecs.open(args.input.name, encoding='utf-8')
if args.output.name != '<stdout>':
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
if args.vocabulary:
args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')
if args.vocabulary:
vocabulary = read_vocabulary(args.vocabulary, args.vocabulary_threshold)
else:
vocabulary = None
if sys.version_info < (3, 0):
args.separator = args.separator.decode('UTF-8')
if args.glossaries:
args.glossaries = [g.decode('UTF-8') for g in args.glossaries]
bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries)
bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries, is_bytes)
for line in args.input:
args.output.write(bpe.process_line(line, args.dropout))
@ -80,18 +86,5 @@ learn-joint-bpe-and-vocab: executes recommended workflow for joint BPE.""")
get_vocab(args.input, args.output)
elif args.command == 'learn-joint-bpe-and-vocab':
learn_joint_bpe_and_vocab(args)
if sys.version_info < (3, 0):
args.separator = args.separator.decode('UTF-8')
else:
raise Exception('Invalid command provided')
# python 2/3 compatibility
if sys.version_info < (3, 0):
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
else:
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)

View File

@ -103,7 +103,7 @@ class TestRegexIsolateGlossaries(unittest.TestCase):
test_case = (orig, exp)
self._run_test_case(test_case)
def encode_mock(segment, x2, x3, x4, x5, x6, x7, glosses, dropout):
def encode_mock(segment, x2, x3, x4, x5, x6, x7, glosses, x8, dropout):
if glosses.match(segment):
return (segment,)
else: