mirror of https://github.com/rsennrich/subword-nmt.git
synced 2024-11-25 21:34:58 +03:00

byte-level BPE. Still not fully tested, fails some unit tests with glossaries.

This commit is contained in:
parent 93f0c93ccd
commit 0607b5443f
CHANGELOG.md
@@ -1,5 +1,8 @@
 CHANGELOG
 ---------
+v0.3.9
+- byte-level BPE support
+- remove support for Python 2
 
 v0.3.8:
 - multiprocessing support (get_vocab and apply_bpe)
setup.py (3 changed lines)
@@ -11,7 +11,7 @@ def test_suite():
 
 setup(
     name='subword_nmt',
-    version='0.3.8',
+    version='0.3.9',
     description='Unsupervised Word Segmentation for Neural Machine Translation and Text Generation',
     long_description=(codecs.open("README.md", encoding='utf-8').read() +
         "\n\n" + codecs.open("CHANGELOG.md", encoding='utf-8').read()),
@@ -25,7 +25,6 @@ setup(
         'Topic :: Text Processing',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
         'License :: OSI Approved :: MIT License',
-        'Programming Language :: Python :: 2',
         'Programming Language :: Python :: 3',
     ],
     install_requires=['mock',
apply_bpe.py
@@ -24,35 +24,44 @@ import warnings
 import random
 import tempfile
 from multiprocessing import Pool, cpu_count
 
-# hack for python2/3 compatibility
-from io import open
-argparse.open = open
+from contextlib import contextmanager
 
 class BPE(object):
 
-    def __init__(self, codes, merges=-1, separator='@@', vocab=None, glossaries=None):
+    def __init__(self, codes, merges=-1, separator='@@', vocab=None, glossaries=None, is_bytes=False):
 
         codes.seek(0)
         offset=1
 
-        # check version information
-        firstline = codes.readline()
-        if firstline.startswith('#version:'):
-            self.version = tuple([int(x) for x in re.sub(r'(\.0+)*$','', firstline.split()[-1]).split(".")])
+        if is_bytes:
+            firstline = codes.readline()
+            self.version = (0, 2)
             offset += 1
         else:
-            self.version = (0, 1)
-            codes.seek(0)
+            firstline = codes.readline()
+            if firstline.startswith('#version:'):
+                self.version = tuple([int(x) for x in re.sub(r'(\.0+)*$','', firstline.split()[-1]).split(".")])
+                offset += 1
+            else:
+                self.version = (0, 1)
+                codes.seek(0)
 
-        self.bpe_codes = [tuple(item.strip('\r\n ').split(' ')) for (n, item) in enumerate(codes.read().rstrip('\n').split('\n')) if (n < merges or merges == -1)]
+        self.strip_chars = b'\r\n ' if is_bytes else '\r\n '
+        self.newline_char = b'\n' if is_bytes else '\n'
+        self.split_char = b' ' if is_bytes else ' '
+
+        self.bpe_codes = [tuple(item.strip(self.strip_chars).split(self.split_char)) for (n, item) in enumerate(codes.read().rstrip(self.newline_char).split(self.newline_char)) if (n < merges or merges == -1)]
 
         for i, item in enumerate(self.bpe_codes):
             if len(item) != 2:
-                sys.stderr.write('Error: invalid line {0} in BPE codes file: {1}\n'.format(i+offset, ' '.join(item)))
+                sys.stderr.write('Error: invalid line {0} in BPE codes file: {1}\n'.format(i+offset, self.split_char.join(item)))
                 sys.stderr.write('The line should exist of exactly two subword units, separated by whitespace\n')
                 sys.exit(1)
 
+        self.is_bytes = is_bytes
+
         # some hacking to deal with duplicates (only consider first instance)
         self.bpe_codes = dict([(code,i) for (i,code) in reversed(list(enumerate(self.bpe_codes)))])
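Note: the constructor now keeps parallel bytes/str flavors of the whitespace constants because Python 3 never mixes the two types in strip/split/join, and because indexing bytes yields ints rather than length-1 sequences. A minimal standalone sketch of the distinction (not part of the commit):

```python
# Standalone sketch: why strip/split/join constants need both flavors.
text_line = "lo w er\n"
byte_line = text_line.encode("utf-8")

assert text_line.strip('\r\n ').split(' ') == ['lo', 'w', 'er']
assert byte_line.strip(b'\r\n ').split(b' ') == [b'lo', b'w', b'er']
# byte_line.split(' ') would raise TypeError: str arguments never apply to bytes

# Indexing differs too: bytes[i] is an int, bytes[i:i+1] a one-byte object.
assert byte_line[0] == 108        # ord('l')
assert byte_line[0:1] == b'l'
```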
@@ -62,22 +71,30 @@ class BPE(object):
 
         self.vocab = vocab
 
-        self.glossaries_regex = re.compile('^({})$'.format('|'.join(glossaries))) if glossaries else None
-        self.glossaries = glossaries if glossaries else []
+        self.glossaries = glossaries if glossaries else []
+        if glossaries:
+            if is_bytes:
+                glossaries = [item.encode('utf-8') for item in glossaries]
+                self.glossaries_regex = re.compile(b'^(' + b'|'.join(glossaries) + b')$')
+            else:
+                self.glossaries_regex = re.compile('^({})$'.format('|'.join(glossaries)))
+        else:
+            self.glossaries_regex = None
 
         self.cache = {}
 
     def process_lines(self, filename, outfile, dropout=0, num_workers=1):
 
-        if sys.version_info < (3, 0):
-            print("Parallel mode is only supported in Python3.")
+        if sys.version_info < (3, 0) :
+            print("Parallel mode is only supported in Python3")
             sys.exit(1)
 
         if num_workers == 1:
             _process_lines(self, filename, outfile, dropout, 0, 0)
         elif num_workers > 1:
-            with open(filename, encoding="utf-8") as f:
+            mode = 'rb' if self.is_bytes else 'r'
+            with open_file(filename, mode) as f:
                 size = os.fstat(f.fileno()).st_size
                 chunk_size = int(size / num_workers)
                 offsets = [0 for _ in range(num_workers + 1)]
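Note: bytes objects have no .format(), so in byte mode the anchored glossary pattern is assembled by concatenation. A standalone sketch of the same construction:

```python
import re

# Standalone sketch of the byte-mode glossary regex.
glossaries = ['USA', 'NASA']
byte_glossaries = [g.encode('utf-8') for g in glossaries]

glossaries_regex = re.compile(b'^(' + b'|'.join(byte_glossaries) + b')$')
assert glossaries_regex.match(b'NASA')
assert glossaries_regex.match(b'NASAB') is None   # anchored to the whole token
```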
@@ -103,7 +120,7 @@ class BPE(object):
             pool.close()
             pool.join()
             for i in range(num_workers):
-                with open(res_files[i].name, encoding="utf-8") as fi:
+                with open_file(res_files[i].name, mode) as fi:
                     for line in fi:
                         outfile.write(line)
                     os.remove(res_files[i].name)
@@ -113,15 +130,15 @@ class BPE(object):
     def process_line(self, line, dropout=0):
         """segment line, dealing with leading and trailing whitespace"""
 
-        out = ""
+        out = b"" if self.is_bytes else ""
 
-        leading_whitespace = len(line)-len(line.lstrip('\r\n '))
+        leading_whitespace = len(line)-len(line.lstrip(self.strip_chars))
         if leading_whitespace:
             out += line[:leading_whitespace]
 
         out += self.segment(line, dropout)
 
-        trailing_whitespace = len(line)-len(line.rstrip('\r\n '))
+        trailing_whitespace = len(line)-len(line.rstrip(self.strip_chars))
         if trailing_whitespace and trailing_whitespace != len(line):
             out += line[-trailing_whitespace:]
@@ -129,8 +146,8 @@ class BPE(object):
 
     def segment(self, sentence, dropout=0):
         """segment single sentence (whitespace-tokenized string) with BPE encoding"""
-        segments = self.segment_tokens(sentence.strip('\r\n ').split(' '), dropout)
-        return ' '.join(segments)
+        segments = self.segment_tokens(sentence.strip(self.strip_chars).split(self.split_char), dropout)
+        return self.split_char.join(segments)
 
     def segment_tokens(self, tokens, dropout=0):
         """segment a sequence of tokens with BPE encoding"""
@@ -148,6 +165,7 @@ class BPE(object):
                                           self.version,
                                           self.cache,
                                           self.glossaries_regex,
+                                          self.is_bytes,
                                           dropout)]
 
             for item in new_word[:-1]:
@@ -160,15 +178,19 @@ class BPE(object):
         word_segments = [word]
         for gloss in self.glossaries:
             word_segments = [out_segments for segment in word_segments
-                                 for out_segments in isolate_glossary(segment, gloss)]
+                                 for out_segments in isolate_glossary(segment, gloss, self.is_bytes)]
         return word_segments
 
 def _process_lines(bpe, filename, outfile, dropout, begin, end):
 
+    write_mode = 'wb' if bpe.is_bytes else 'w'
+    read_mode = 'rb' if bpe.is_bytes else 'r'
+
     if isinstance(outfile, str):
-        fo = open(outfile, "w", encoding="utf-8")
+        fo = open_file(outfile, write_mode)
     else:
         fo = outfile
-    with open(filename, encoding="utf-8") as f:
+    with open_file(filename, read_mode) as f:
         f.seek(begin)
         line = f.readline()
         while line:
@@ -181,6 +203,17 @@ def _process_lines(bpe, filename, outfile, dropout, begin, end):
     if isinstance(outfile, str):
         fo.close()
 
+@contextmanager
+def open_file(filename, mode):
+    if mode in ('r', 'w'):
+        f = open(filename, mode, encoding="utf-8")
+    elif mode in ('rb', 'wb'):
+        f = open(filename, mode)
+    try:
+        yield f
+    finally:
+        f.close()
+
 def create_parser(subparsers=None):
 
     if subparsers:
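Note: open_file is a generator-based context manager, so it yields a usable file object only inside a `with` block; `_process_lines` above also binds `fo = open_file(outfile, write_mode)` directly, which stores the context manager itself rather than a file — plausibly one of the paths the commit message flags as not fully tested. Intended usage, assuming the definition above is in scope:

```python
import tempfile

# Sketch of open_file's intended use: text modes get UTF-8, byte modes raw streams.
tmp = tempfile.NamedTemporaryFile(delete=False)
tmp.close()

with open_file(tmp.name, 'wb') as f:
    f.write(b'lo w er\n')

with open_file(tmp.name, 'rb') as f:
    assert f.readline() == b'lo w er\n'
```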
@@ -193,11 +226,11 @@ def create_parser(subparsers=None):
         description="learn BPE-based word segmentation")
 
     parser.add_argument(
-        '--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
+        '--input', '-i', type=argparse.FileType('rb'), default=sys.stdin,
         metavar='PATH',
         help="Input file (default: standard input).")
     parser.add_argument(
-        '--codes', '-c', type=argparse.FileType('r'), metavar='PATH',
+        '--codes', '-c', type=argparse.FileType('rb'), metavar='PATH',
         required=True,
         help="File with BPE codes (created by learn_bpe.py).")
     parser.add_argument(
@@ -206,14 +239,14 @@ def create_parser(subparsers=None):
         help="Use this many BPE operations (<= number of learned symbols)"+
             "default: Apply all the learned merge operations")
     parser.add_argument(
-        '--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
+        '--output', '-o', type=argparse.FileType('wb'), default=sys.stdout,
         metavar='PATH',
         help="Output file (default: standard output)")
     parser.add_argument(
-        '--separator', '-s', type=str, default='@@', metavar='STR',
+        '--separator', '-s', type=bytes, default=b'@@', metavar='STR',
         help="Separator between non-final subword units (default: '%(default)s'))")
     parser.add_argument(
-        '--vocabulary', type=argparse.FileType('r'), default=None,
+        '--vocabulary', type=argparse.FileType('rb'), default=None,
         metavar="PATH",
         help="Vocabulary file (built with get_vocab.py). If provided, this script reverts any merge operations that produce an OOV.")
     parser.add_argument(
@@ -240,7 +273,7 @@ def create_parser(subparsers=None):
 
     return parser
 
-def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries_regex=None, dropout=0):
+def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries_regex=None, is_bytes=False, dropout=0):
     """Encode word based on list of BPE merge operations, which are applied consecutively
     """
 
@@ -252,12 +285,16 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache,
         return (orig,)
 
     if len(orig) == 1:
-        return orig
+        return (orig,)
 
-    if version == (0, 1):
+    eow = b'</w>' if is_bytes else '</w>'
+
+    if is_bytes:
+        word = list(map(lambda b: bytes([b]), orig[:-1])) + [orig[-1:] + eow]
+    elif version == (0, 1):
         word = list(orig) + ['</w>']
     elif version == (0, 2): # more consistent handling of word-final segments
-        word = list(orig[:-1]) + [orig[-1] + '</w>']
+        word = list(orig[:-1]) + [orig[-1] + eow]
     else:
         raise NotImplementedError
 
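Note: iterating a bytes object yields ints, so the byte path re-wraps every byte as a length-1 bytes object before merging. Merges therefore operate on raw bytes and can join fragments of different multibyte UTF-8 characters; decoding back to str is only safe once segments are reassembled. A standalone sketch:

```python
# Standalone sketch of the byte-mode word setup in encode().
orig = 'Tür'.encode('utf-8')            # b'T\xc3\xbcr': three characters, four bytes
eow = b'</w>'

word = list(map(lambda b: bytes([b]), orig[:-1])) + [orig[-1:] + eow]
assert word == [b'T', b'\xc3', b'\xbc', b'r</w>']

# The version-0.2 str path, for comparison:
assert list('Tür'[:-1]) + ['Tür'[-1] + '</w>'] == ['T', 'ü', 'r</w>']
```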
@@ -277,7 +314,10 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache,
 
         i = 0
         new_word = []
-        bigram = ''.join(bigram)
+        if is_bytes:
+            bigram = b''.join(bigram)
+        else:
+            bigram = ''.join(bigram)
         for j in positions:
             # merges are invalid if they start before current position. This can happen if there are overlapping pairs: (x x x -> xx x)
             if j < i:
@@ -289,9 +329,9 @@ def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache,
         word = new_word
 
     # don't print end-of-word symbols
-    if word[-1] == '</w>':
+    if word[-1] == eow:
         word = word[:-1]
-    elif word[-1].endswith('</w>'):
+    elif word[-1].endswith(eow):
         word[-1] = word[-1][:-4]
 
     word = tuple(word)
@@ -367,7 +407,7 @@ def read_vocabulary(vocab_file, threshold):
 
     return vocabulary
 
-def isolate_glossary(word, glossary):
+def isolate_glossary(word, glossary, is_bytes=False):
     """
     Isolate a glossary present inside a word.
 
@@ -376,14 +416,34 @@ def isolate_glossary(word, glossary):
     For example, if 'USA' is the glossary and '1934USABUSA' the word, the return value is:
         ['1934', 'USA', 'B', 'USA']
     """
+    if is_bytes:
+        pattern = b'^'+glossary+b'$'
+    else:
+        pattern = '^'+glossary+'$'
+    strip_chars = b'\r\n ' if is_bytes else '\r\n '
+    empty_string = b'' if is_bytes else ''
+
     # regex equivalent of (if word == glossary or glossary not in word)
-    if re.match('^'+glossary+'$', word) or not re.search(glossary, word):
+    if re.match(pattern, word) or not re.search(glossary, word):
         return [word]
     else:
-        segments = re.split(r'({})'.format(glossary), word)
-        segments, ending = segments[:-1], segments[-1]
+        if is_bytes:
+            segments = re.split(rb'(' + glossary + rb')', word)
+            segments, ending = segments[:-1], segments[-1:]
+        else:
+            segments = re.split(r'({})'.format(glossary), word)
+            segments, ending = segments[:-1], segments[-1]
         segments = list(filter(None, segments)) # Remove empty strings in regex group.
-        return segments + [ending.strip('\r\n ')] if ending != '' else segments
+        return segments + [ending[0].strip(strip_chars)] if ending != empty_string else segments
 
+# first line of BPE code file indicates if it is byte-level or UTF-8
+def get_byte_mode(code_file_name):
+    firstline = open(code_file_name, mode='rb').readline()
+    if firstline.endswith(b'byte\n'):
+        return True
+    else:
+        return False
 
 if __name__ == '__main__':
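Note: the codes file itself carries the mode flag — learn_bpe writes '#version: 0.2 byte' in byte mode, and get_byte_mode only inspects that first line. A sketch, assuming the definition above is in scope:

```python
import tempfile

# Standalone sketch: the first line of the codes file decides byte mode.
codes = tempfile.NamedTemporaryFile(delete=False)
codes.write(b'#version: 0.2 byte\n')   # header written by learn_bpe when is_bytes
codes.write(b'l o\n')
codes.close()

assert get_byte_mode(codes.name) is True    # first line ends with b'byte\n'
```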
@@ -397,9 +457,8 @@ if __name__ == '__main__':
 
     # python 2/3 compatibility
     if sys.version_info < (3, 0):
-        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
-        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
-        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
+        print("Python 2 is deprecated. Use Python 3")
+        sys.exit(1)
     else:
         sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
         sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
@@ -411,33 +470,37 @@ if __name__ == '__main__':
     if args.num_workers <= 0:
         args.num_workers = cpu_count()
 
-    # read/write files as UTF-8
-    args.codes = codecs.open(args.codes.name, encoding='utf-8')
-    if args.input.name != '<stdin>':
-        args.input = codecs.open(args.input.name, encoding='utf-8')
-    if args.output.name != '<stdout>':
-        args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
-    if args.vocabulary:
-        args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')
+    # check if codes are bytes or UTF-8
+    is_bytes = get_byte_mode(args.codes.name)
+
+    args.separator = args.separator.decode('UTF-8') if not is_bytes else args.separator
+
+    # read/write files as bytes or UTF-8, depending on mode
+
+    if is_bytes:
+        if args.input.name == '<stdin>':
+            args.input = sys.stdin.buffer
+        if args.output.name == '<stdout>':
+            args.output = sys.stdout.buffer
+    else:
+        args.codes = codecs.open(args.codes.name, encoding='utf-8')
+        if args.input.name != '<stdin>':
+            args.input = codecs.open(args.input.name, encoding='utf-8')
+        if args.output.name != '<stdout>':
+            args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
+        if args.vocabulary:
+            args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')
 
     if args.vocabulary:
         vocabulary = read_vocabulary(args.vocabulary, args.vocabulary_threshold)
     else:
         vocabulary = None
 
-    if sys.version_info < (3, 0):
-        args.separator = args.separator.decode('UTF-8')
-        if args.glossaries:
-            args.glossaries = [g.decode('UTF-8') for g in args.glossaries]
-        if args.num_workers > 1:
-            args.num_workers = 1
-            warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
-
     if args.seed is not None:
         random.seed(args.seed)
 
-    bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries)
+    bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries, is_bytes)
 
     if args.input.name == '<stdin>' or args.num_workers == 1:
         if args.num_workers > 1:
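Putting the apply side together, a hedged sketch of the new entry-point logic ('codes.bpe' is a hypothetical path; assumes subword_nmt.apply_bpe with this commit applied):

```python
from subword_nmt.apply_bpe import BPE, get_byte_mode

codes_path = 'codes.bpe'               # hypothetical output of learn_bpe
is_bytes = get_byte_mode(codes_path)   # inspects the '#version' header line

if is_bytes:
    with open(codes_path, 'rb') as codes:
        bpe = BPE(codes, separator=b'@@', is_bytes=True)
    out = bpe.process_line('lower newer\n'.encode('utf-8'))  # bytes in, bytes out
else:
    with open(codes_path, encoding='utf-8') as codes:
        bpe = BPE(codes, separator='@@')
    out = bpe.process_line('lower newer\n')
```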
learn_bpe.py
@@ -31,10 +31,6 @@ except ImportError:
     def tqdm(iterator, *args, **kwargs):
         return iterator
 
-# hack for python2/3 compatibility
-from io import open
-argparse.open = open
-
 def create_parser(subparsers=None):
 
     if subparsers:
@@ -47,12 +43,11 @@ def create_parser(subparsers=None):
         description="learn BPE-based word segmentation")
 
     parser.add_argument(
-        '--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
+        '--input', '-i', type=argparse.FileType('rb'), default=sys.stdin,
         metavar='PATH',
         help="Input text (default: standard input).")
-
     parser.add_argument(
-        '--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
+        '--output', '-o', type=argparse.FileType('wb'), default=sys.stdout,
         metavar='PATH',
         help="Output file for BPE codes (default: standard output)")
     parser.add_argument(
@@ -61,6 +56,9 @@ def create_parser(subparsers=None):
     parser.add_argument(
         '--min-frequency', type=int, default=2, metavar='FREQ',
         help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
+    parser.add_argument(
+        '--byte', '-b', action="store_true",
+        help="byte-level BPE.")
     parser.add_argument('--dict-input', action="store_true",
         help="If set, input file is interpreted as a dictionary where each line contains a word-count pair")
     parser.add_argument(
@@ -75,14 +73,18 @@ def create_parser(subparsers=None):
 
     return parser
 
-def get_vocabulary(fobj, is_dict=False, num_workers=1):
+def get_vocabulary(fobj, is_dict=False, is_bytes=False, num_workers=1):
     """Read text and return dictionary that encodes vocabulary
     """
     vocab = Counter()
 
+    strip_chars = b'\r\n ' if is_bytes else '\r\n '
+    split_char = b' ' if is_bytes else ' '
+
     if is_dict:
         for i, line in enumerate(fobj):
             try:
-                word, count = line.strip('\r\n ').split(' ')
+                word, count = line.strip(strip_chars).split(split_char)
             except:
                 print('Failed reading vocabulary file at line {0}: {1}'.format(i, line))
                 sys.exit(1)
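Note: with is_bytes the Counter keys become bytes objects. A standalone sketch of the counting loop:

```python
from collections import Counter

# Standalone sketch of byte-mode vocabulary counting.
strip_chars, split_char = b'\r\n ', b' '
vocab = Counter()
for line in [b'low lower\n', b'low newest\n']:
    for word in line.strip(strip_chars).split(split_char):
        if word:
            vocab[word] += 1

assert vocab == Counter({b'low': 2, b'lower': 1, b'newest': 1})
```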
@@ -91,13 +93,13 @@ def get_vocabulary(fobj, is_dict=False, num_workers=1):
         if num_workers > 1:
             warnings.warn("In parallel mode, the input cannot be STDIN. Using 1 processor instead.")
         for i, line in enumerate(fobj):
-            for word in line.strip('\r\n ').split(' '):
+            for word in line.strip(strip_chars).split(split_char):
                 if word:
                     vocab[word] += 1
     elif num_workers > 1:
 
-        if sys.version_info < (3, 0):
-            print("Parallel mode is only supported in Python3.")
+        if is_bytes:
+            print('byte-level BPE not yet ported to parallel mode')
             sys.exit(1)
 
         with open(fobj.name, encoding="utf8") as f:
@@ -231,13 +233,23 @@ def get_pair_statistics(vocab):
     return stats, indices
 
 
-def replace_pair(pair, vocab, indices):
+def replace_pair(pair, vocab, indices, is_bytes=False):
     """Replace all occurrences of a symbol pair ('A', 'B') with a new symbol 'AB'"""
+    split_char = b' ' if is_bytes else ' '
     first, second = pair
-    pair_str = ''.join(pair)
-    pair_str = pair_str.replace('\\','\\\\')
+
+    if is_bytes:
+        pair_str = b''.join(pair)
+        # first = bytes((first,))
+        # second = bytes((second,))
+        pair_str = pair_str.replace(b'\\',b'\\\\')
+        pattern = re.compile(rb'(?<!\S)' + re.escape(first + split_char + second) + rb'(?!\S)')
+    else:
+        pair_str = ''.join(pair)
+        pair_str = pair_str.replace('\\','\\\\')
+        pattern = re.compile(r'(?<!\S)' + re.escape(first + split_char + second) + r'(?!\S)')
     changes = []
-    pattern = re.compile(r'(?<!\S)' + re.escape(first + ' ' + second) + r'(?!\S)')
     if sys.version_info < (3, 0):
         iterator = indices[pair].iteritems()
     else:
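Note: pattern, replacement, and subject must all be bytes together; re.escape accepts bytes and rb'...' literals give bytes patterns. A standalone sketch of the byte-mode rewrite:

```python
import re

# Standalone sketch of the byte-mode merge rewrite in replace_pair.
first, second = b'l', b'o'
pair_str = (first + second).replace(b'\\', b'\\\\')
pattern = re.compile(rb'(?<!\S)' + re.escape(first + b' ' + second) + rb'(?!\S)')

word = (b'l', b'o', b'w', b'</w>')
new_word = pattern.sub(pair_str, b' '.join(word))
assert new_word == b'lo w </w>'
assert tuple(new_word.split(b' ')) == (b'lo', b'w', b'</w>')
```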
@@ -246,9 +258,9 @@ def replace_pair(pair, vocab, indices):
         if freq < 1:
             continue
         word, freq = vocab[j]
-        new_word = ' '.join(word)
+        new_word = split_char.join(word)
         new_word = pattern.sub(pair_str, new_word)
-        new_word = tuple(new_word.split(' '))
+        new_word = tuple(new_word.split(split_char))
 
         vocab[j] = (new_word, freq)
         changes.append((j, new_word, word, freq))
@@ -271,16 +283,22 @@ def prune_stats(stats, big_stats, threshold):
             big_stats[item] = freq
 
 
-def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False, num_workers=1):
+def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, is_bytes=False, total_symbols=False, num_workers=1):
     """Learn num_symbols BPE operations from vocabulary, and write to outfile.
     """
 
     # version 0.2 changes the handling of the end-of-word token ('</w>');
     # version numbering allows bckward compatibility
-    outfile.write('#version: 0.2\n')
+    if is_bytes:
+        outfile.write(b'#version: 0.2 byte\n')
+    else:
+        outfile.write('#version: 0.2\n')
 
-    vocab = get_vocabulary(infile, is_dict, num_workers)
-    vocab = dict([(tuple(x[:-1])+(x[-1]+'</w>',) ,y) for (x,y) in vocab.items()])
+    vocab = get_vocabulary(infile, is_dict, is_bytes, num_workers)
+    if is_bytes:
+        vocab = dict([(tuple(map(lambda b: bytes([b]), x[:-1]))+(x[-1:]+b'</w>',) ,y) for (x,y) in vocab.items()])
+    else:
+        vocab = dict([(tuple(x[:-1])+(x[-1]+'</w>',) ,y) for (x,y) in vocab.items()])
     sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)
 
     stats, indices = get_pair_statistics(sorted_vocab)
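Note: in byte mode each counted word is decomposed into single bytes, with the end-of-word marker appended to the final byte. A standalone sketch of the transform:

```python
# Standalone sketch of the byte-mode end-of-word transform in learn_bpe.
vocab = {b'low': 2}

byte_vocab = dict([(tuple(map(lambda b: bytes([b]), x[:-1])) + (x[-1:] + b'</w>',), y)
                   for (x, y) in vocab.items()])
assert byte_vocab == {(b'l', b'o', b'w</w>'): 2}
```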
@@ -319,8 +337,11 @@ def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_d
 
         if verbose:
             sys.stderr.write('pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format(i, most_frequent[0], most_frequent[1], stats[most_frequent]))
-        outfile.write('{0} {1}\n'.format(*most_frequent))
-        changes = replace_pair(most_frequent, sorted_vocab, indices)
+        if is_bytes:
+            outfile.write(most_frequent[0] + b' ' + most_frequent[1] + b'\n')
+        else:
+            outfile.write('{0} {1}\n'.format(*most_frequent))
+        changes = replace_pair(most_frequent, sorted_vocab, indices, is_bytes)
         update_pair_statistics(most_frequent, changes, stats, indices)
         stats[most_frequent] = 0
         if not i % 100:
@@ -337,33 +358,34 @@ if __name__ == '__main__':
             DeprecationWarning
         )
 
-    # python 2/3 compatibility
-    if sys.version_info < (3, 0):
-        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
-        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
-        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
-    else:
+    parser = create_parser()
+    args = parser.parse_args()
+
+    if not args.byte:
         sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
         sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
         sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)
 
-    parser = create_parser()
-    args = parser.parse_args()
-
     if args.num_workers <= 0:
         args.num_workers = cpu_count()
 
-    if sys.version_info < (3, 0) and args.num_workers > 1:
-        args.num_workers = 1
-        warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
+    if sys.version_info < (3, 0):
+        print("Python 2 is deprecated. Use Python 3")
+        sys.exit(1)
 
     # read/write files as UTF-8
-    if args.input.name != '<stdin>':
+    if args.input.name != '<stdin>' and not args.byte:
         args.input = codecs.open(args.input.name, encoding='utf-8')
-    if args.output.name != '<stdout>':
+    if args.output.name != '<stdout>' and not args.byte:
         args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
 
-    learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols, num_workers=args.num_workers)
+    if args.byte:
+        if args.input.name == '<stdin>':
+            args.input = sys.stdin.buffer
+        if args.output.name == '<stdout>':
+            args.output = sys.stdout.buffer
+
+    learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, is_bytes=args.byte, total_symbols=args.total_symbols, num_workers=args.num_workers)
 
     # close files
     if args.input.name != '<stdin>':
learn_joint_bpe_and_vocab.py
@@ -32,10 +32,6 @@ else:
     from . import learn_bpe
     from . import apply_bpe
 
-# hack for python2/3 compatibility
-from io import open
-argparse.open = open
-
 def create_parser(subparsers=None):
 
     if subparsers:
@@ -48,21 +44,24 @@ def create_parser(subparsers=None):
         description="learn BPE-based word segmentation")
 
     parser.add_argument(
-        '--input', '-i', type=argparse.FileType('r'), required=True, nargs = '+',
+        '--input', '-i', type=argparse.FileType('rb'), required=True, nargs = '+',
         metavar='PATH',
         help="Input texts (multiple allowed).")
     parser.add_argument(
-        '--output', '-o', type=argparse.FileType('w'), required=True,
+        '--output', '-o', type=argparse.FileType('wb'), required=True,
         metavar='PATH',
         help="Output file for BPE codes.")
     parser.add_argument(
         '--symbols', '-s', type=int, default=10000,
         help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")
     parser.add_argument(
-        '--separator', type=str, default='@@', metavar='STR',
+        '--byte', '-b', action="store_true",
+        help="byte-level BPE.")
+    parser.add_argument(
+        '--separator', type=bytes, default=b'@@', metavar='STR',
         help="Separator between non-final subword units (default: '%(default)s')")
     parser.add_argument(
-        '--write-vocabulary', type=argparse.FileType('w'), required=True, nargs = '+', default=None,
+        '--write-vocabulary', type=argparse.FileType('wb'), required=True, nargs = '+', default=None,
         metavar='PATH', dest='vocab',
         help='Write to these vocabulary files after applying BPE. One per input text. Used for filtering in apply_bpe.py')
     parser.add_argument(
@@ -86,45 +85,78 @@ def learn_joint_bpe_and_vocab(args):
         sys.stderr.write('Error: number of input files and vocabulary files must match\n')
         sys.exit(1)
 
-    # read/write files as UTF-8
-    args.input = [codecs.open(f.name, encoding='UTF-8') for f in args.input]
-    args.vocab = [codecs.open(f.name, 'w', encoding='UTF-8') for f in args.vocab]
+    if args.byte:
+        # read/write files as byte streams
+        args.input = [codecs.open(f.name, 'rb') for f in args.input]
+        args.vocab = [codecs.open(f.name, 'wb') for f in args.vocab]
+    else:
+        # read/write files as UTF-8
+        args.input = [codecs.open(f.name, encoding='UTF-8') for f in args.input]
+        args.vocab = [codecs.open(f.name, 'w', encoding='UTF-8') for f in args.vocab]
+
+    args.separator = args.separator.decode('UTF-8') if not args.byte else args.separator
 
     # get combined vocabulary of all input texts
     full_vocab = Counter()
     for f in args.input:
-        full_vocab += learn_bpe.get_vocabulary(f, num_workers=args.num_workers)
+        full_vocab += learn_bpe.get_vocabulary(f, num_workers=args.num_workers, is_bytes=args.byte)
         f.seek(0)
 
-    vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]
+    if args.byte:
+        vocab_list = [key + b' ' + str(freq).encode('UTF-8') for (key, freq) in full_vocab.items()]
+    else:
+        vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]
 
     # learn BPE on combined vocabulary
-    with codecs.open(args.output.name, 'w', encoding='UTF-8') as output:
-        learn_bpe.learn_bpe(vocab_list, output, args.symbols, args.min_frequency, args.verbose, is_dict=True, total_symbols=args.total_symbols)
-
-    with codecs.open(args.output.name, encoding='UTF-8') as codes:
-        bpe = apply_bpe.BPE(codes, separator=args.separator)
+    if args.byte:
+        with open(args.output.name, 'wb') as output:
+            learn_bpe.learn_bpe(vocab_list, output, args.symbols, args.min_frequency, args.verbose, is_dict=True, is_bytes=args.byte, total_symbols=args.total_symbols)
+
+        with open(args.output.name, 'rb') as codes:
+            bpe = apply_bpe.BPE(codes, separator=args.separator, is_bytes=args.byte)
+    else:
+        with codecs.open(args.output.name, 'w', encoding='UTF-8') as output:
+            learn_bpe.learn_bpe(vocab_list, output, args.symbols, args.min_frequency, args.verbose, is_dict=True, is_bytes=args.byte, total_symbols=args.total_symbols)
+
+        with codecs.open(args.output.name, encoding='UTF-8') as codes:
+            bpe = apply_bpe.BPE(codes, separator=args.separator, is_bytes=args.byte)
 
     # apply BPE to each training corpus and get vocabulary
     for train_file, vocab_file in zip(args.input, args.vocab):
 
         # read/write files as UTF-8
-        train_file = codecs.open(train_file.name, encoding='utf-8')
-        vocab_file = codecs.open(vocab_file.name, 'w', encoding='utf-8')
+        if not args.byte:
+            train_file = codecs.open(train_file.name, encoding='utf-8')
+            vocab_file = codecs.open(vocab_file.name, 'w', encoding='utf-8')
 
         tmp = tempfile.NamedTemporaryFile(delete=False)
         tmp.close()
 
-        tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')
+        if args.byte:
+            tmpout = open(tmp.name, 'wb')
+        else:
+            tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')
 
         train_file.seek(0)
         bpe.process_lines(train_file.name, tmpout, num_workers=args.num_workers)
 
         tmpout.close()
-        tmpin = codecs.open(tmp.name, encoding='UTF-8')
 
-        vocab = learn_bpe.get_vocabulary(tmpin, num_workers=args.num_workers)
+        if args.byte:
+            tmpin = open(tmp.name, 'rb')
+        else:
+            tmpin = codecs.open(tmp.name, encoding='UTF-8')
+
+        vocab = learn_bpe.get_vocabulary(tmpin, num_workers=args.num_workers, is_bytes=args.byte)
         tmpin.close()
         os.remove(tmp.name)
 
         for key, freq in sorted(vocab.items(), key=lambda x: x[1], reverse=True):
-            vocab_file.write("{0} {1}\n".format(key, freq))
+            if args.byte:
+                vocab_file.write(key + b" " + str(freq).encode('utf-8') + b"\n")
+            else:
+                vocab_file.write("{0} {1}\n".format(key, freq))
         train_file.close()
         vocab_file.close()
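Note: the dictionary-style vocab lines are likewise assembled by bytes concatenation in byte mode, since bytes objects have no .format(). A standalone sketch:

```python
from collections import Counter

# Standalone sketch of byte-mode vocab_list construction.
full_vocab = Counter({b'low': 2, b'newest': 1})
vocab_list = [key + b' ' + str(freq).encode('UTF-8') for (key, freq) in full_vocab.items()]
assert sorted(vocab_list) == [b'low 2', b'newest 1']
```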
@@ -141,13 +173,12 @@ if __name__ == '__main__':
 
     # python 2/3 compatibility
     if sys.version_info < (3, 0):
-        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
-        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
-        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
-    else:
-        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
-        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
-        sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)
+        print("Python 2 is deprecated. Use Python 3")
+        sys.exit(1)
+
+    sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
+    sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
+    sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)
 
     parser = create_parser()
     args = parser.parse_args()
@@ -155,12 +186,6 @@ if __name__ == '__main__':
     if args.num_workers <= 0:
         args.num_workers = cpu_count()
 
-    if sys.version_info < (3, 0):
-        args.separator = args.separator.decode('UTF-8')
-        if args.num_workers > 1:
-            args.num_workers = 1
-            warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
-
     assert(len(args.input) == len(args.vocab))
 
     learn_joint_bpe_and_vocab(args)
subword_nmt.py
@@ -7,7 +7,7 @@ import codecs
 import argparse
 
 from .learn_bpe import learn_bpe
-from .apply_bpe import BPE, read_vocabulary
+from .apply_bpe import BPE, read_vocabulary, get_byte_mode
 from .get_vocab import get_vocab
 from .learn_joint_bpe_and_vocab import learn_joint_bpe_and_vocab
 
@@ -16,9 +16,6 @@ from .apply_bpe import create_parser as create_apply_bpe_parser
 from .get_vocab import create_parser as create_get_vocab_parser
 from .learn_joint_bpe_and_vocab import create_parser as create_learn_joint_bpe_and_vocab_parser
 
-# hack for python2/3 compatibility
-argparse.open = io.open
-
 def main():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.RawTextHelpFormatter,
@ -39,35 +36,44 @@ learn-joint-bpe-and-vocab: executes recommended workflow for joint BPE.""")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == 'learn-bpe':
|
||||
# read/write files as UTF-8
|
||||
if args.input.name != '<stdin>':
|
||||
args.input = codecs.open(args.input.name, encoding='utf-8')
|
||||
if args.output.name != '<stdout>':
|
||||
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
|
||||
if args.byte:
|
||||
if args.input.name == '<stdin>':
|
||||
args.input = sys.stdin.buffer
|
||||
if args.output.name == '<stdout>':
|
||||
args.output = sys.stdout.buffer
|
||||
else:
|
||||
# read/write files as UTF-8
|
||||
if args.input.name != '<stdin>':
|
||||
args.input = codecs.open(args.input.name, encoding='utf-8')
|
||||
if args.output.name != '<stdout>':
|
||||
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
|
||||
|
||||
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose,
|
||||
is_dict=args.dict_input, total_symbols=args.total_symbols)
|
||||
is_dict=args.dict_input, is_bytes=args.byte, total_symbols=args.total_symbols)
|
||||
elif args.command == 'apply-bpe':
|
||||
# read/write files as UTF-8
|
||||
args.codes = codecs.open(args.codes.name, encoding='utf-8')
|
||||
if args.input.name != '<stdin>':
|
||||
args.input = codecs.open(args.input.name, encoding='utf-8')
|
||||
if args.output.name != '<stdout>':
|
||||
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
|
||||
if args.vocabulary:
|
||||
args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')
|
||||
is_bytes = get_byte_mode(args.codes.name)
|
||||
|
||||
if is_bytes:
|
||||
if args.input.name == '<stdin>':
|
||||
args.input = sys.stdin.buffer
|
||||
if args.output.name == '<stdout>':
|
||||
args.output = sys.stdout.buffer
|
||||
else:
|
||||
# read/write files as UTF-8
|
||||
args.codes = codecs.open(args.codes.name, encoding='utf-8')
|
||||
if args.input.name != '<stdin>':
|
||||
args.input = codecs.open(args.input.name, encoding='utf-8')
|
||||
if args.output.name != '<stdout>':
|
||||
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
|
||||
if args.vocabulary:
|
||||
args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8')
|
||||
|
||||
if args.vocabulary:
|
||||
vocabulary = read_vocabulary(args.vocabulary, args.vocabulary_threshold)
|
||||
else:
|
||||
vocabulary = None
|
||||
|
||||
if sys.version_info < (3, 0):
|
||||
args.separator = args.separator.decode('UTF-8')
|
||||
if args.glossaries:
|
||||
args.glossaries = [g.decode('UTF-8') for g in args.glossaries]
|
||||
|
||||
bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries)
|
||||
bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries, is_bytes)
|
||||
|
||||
for line in args.input:
|
||||
args.output.write(bpe.process_line(line, args.dropout))
|
||||
@ -80,18 +86,5 @@ learn-joint-bpe-and-vocab: executes recommended workflow for joint BPE.""")
|
||||
get_vocab(args.input, args.output)
|
||||
elif args.command == 'learn-joint-bpe-and-vocab':
|
||||
learn_joint_bpe_and_vocab(args)
|
||||
if sys.version_info < (3, 0):
|
||||
args.separator = args.separator.decode('UTF-8')
|
||||
else:
|
||||
raise Exception('Invalid command provided')
|
||||
|
||||
|
||||
# python 2/3 compatibility
|
||||
if sys.version_info < (3, 0):
|
||||
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
|
||||
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
|
||||
sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
|
||||
else:
|
||||
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
|
||||
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
|
||||
sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)
|
||||
|
test_glossaries.py
@@ -103,7 +103,7 @@ class TestRegexIsolateGlossaries(unittest.TestCase):
             test_case = (orig, exp)
             self._run_test_case(test_case)
 
-def encode_mock(segment, x2, x3, x4, x5, x6, x7, glosses, dropout):
+def encode_mock(segment, x2, x3, x4, x5, x6, x7, glosses, x8, dropout):
     if glosses.match(segment):
         return (segment,)
     else:
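Note: the mock's positional slots must stay aligned with encode()'s signature; x8 fills the new is_bytes parameter that now precedes dropout. A standalone illustration (hypothetical stub):

```python
import inspect

# Hypothetical stub mirroring encode()'s parameter order after this commit.
def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version,
           cache, glossaries_regex=None, is_bytes=False, dropout=0):
    ...

params = list(inspect.signature(encode).parameters)
assert params[8] == 'is_bytes'   # the slot encode_mock fills with x8
assert params[9] == 'dropout'
```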