Add parallel support (--num-workers)

This commit is contained in:
yimmon 2020-06-18 01:04:38 +08:00
parent ef69c1eb27
commit bce5f8abcf
3 changed files with 173 additions and 20 deletions

View File

@ -22,7 +22,8 @@ import argparse
import re
import warnings
import random
import tempfile
from multiprocessing import Pool, cpu_count
# hack for python2/3 compatibility
from io import open
@ -67,6 +68,48 @@ class BPE(object):
self.cache = {}
def process_lines(self, filename, outfile, dropout=0, num_workers=1):
if sys.version_info < (3, 0):
print("Parallel mode is only supported in Python3.")
sys.exit(1)
if num_workers == 1:
_process_lines(self, filename, outfile, dropout, 0, 0)
elif num_workers > 1:
with open(filename, encoding="utf-8") as f:
size = os.fstat(f.fileno()).st_size
chunk_size = int(size / num_workers)
offsets = [0 for _ in range(num_workers + 1)]
for i in range(1, num_workers):
f.seek(chunk_size * i)
pos = f.tell()
while True:
try:
line = f.readline()
break
except UnicodeDecodeError:
pos -= 1
f.seek(pos)
offsets[i] = f.tell()
assert 0 <= offsets[i] < 1e20, "Bad new line separator, e.g. '\\r'"
res_files = []
pool = Pool(processes=num_workers)
for i in range(num_workers):
tmp = tempfile.NamedTemporaryFile(delete=False)
tmp.close()
res_files.append(tmp)
pool.apply_async(_process_lines, (self, filename, tmp.name, dropout, offsets[i], offsets[i + 1]))
pool.close()
pool.join()
for i in range(num_workers):
with open(res_files[i].name, encoding="utf-8") as fi:
for line in fi:
outfile.write(line)
os.remove(res_files[i].name)
else:
raise ValueError('`num_workers` is expected to be a positive number, but got {}.'.format(num_workers))
def process_line(self, line, dropout=0):
"""segment line, dealing with leading and trailing whitespace"""
@ -120,6 +163,23 @@ class BPE(object):
for out_segments in isolate_glossary(segment, gloss)]
return word_segments
def _process_lines(bpe, filename, outfile, dropout, begin, end):
if isinstance(outfile, str):
fo = open(outfile, "w", encoding="utf-8")
else:
fo = outfile
with open(filename, encoding="utf-8") as f:
f.seek(begin)
line = f.readline()
while line:
assert 0 <= f.tell() < 1e20, "Bad new line separator, e.g. '\\r'"
if end > 0 and f.tell() > end:
break
fo.write(bpe.process_line(line, dropout))
line = f.readline()
if isinstance(outfile, str):
fo.close()
def create_parser(subparsers=None):
if subparsers:
@ -173,6 +233,9 @@ def create_parser(subparsers=None):
'--seed', type=int, default=None,
metavar="S",
help="Random seed for the random number generators (e.g. for BPE dropout with --dropout).")
parser.add_argument(
'--num-workers', type=int, default=1,
help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
return parser
@ -345,6 +408,9 @@ if __name__ == '__main__':
parser = create_parser()
args = parser.parse_args()
if args.num_workers <= 0:
args.num_workers = cpu_count()
# read/write files as UTF-8
args.codes = codecs.open(args.codes.name, encoding='utf-8')
if args.input.name != '<stdin>':
@ -363,11 +429,19 @@ if __name__ == '__main__':
args.separator = args.separator.decode('UTF-8')
if args.glossaries:
args.glossaries = [g.decode('UTF-8') for g in args.glossaries]
if args.num_workers > 1:
args.num_workers = 1
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
if args.seed is not None:
random.seed(args.seed)
bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries)
for line in args.input:
args.output.write(bpe.process_line(line, args.dropout))
if args.input.name == '<stdin>' or args.num_workers == 1:
if args.num_workers > 1:
warnings.warn("In parallel mode, the input cannot be STDIN. Using 1 processor instead.")
for line in args.input:
args.output.write(bpe.process_line(line, args.dropout))
else:
bpe.process_lines(args.input.name, args.output, args.dropout, args.num_workers)

View File

@ -21,6 +21,8 @@ import re
import copy
import argparse
import warnings
import tempfile
from multiprocessing import Pool, cpu_count
from collections import defaultdict, Counter
# hack for python2/3 compatibility
@ -49,39 +51,101 @@ def create_parser(subparsers=None):
help="Output file for BPE codes (default: standard output)")
parser.add_argument(
'--symbols', '-s', type=int, default=10000,
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))")
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")
parser.add_argument(
'--min-frequency', type=int, default=2, metavar='FREQ',
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))')
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
parser.add_argument('--dict-input', action="store_true",
help="If set, input file is interpreted as a dictionary where each line contains a word-count pair")
parser.add_argument(
'--total-symbols', '-t', action="store_true",
help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
parser.add_argument(
'--num-workers', type=int, default=1,
help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
parser.add_argument(
'--verbose', '-v', action="store_true",
help="verbose mode.")
return parser
def get_vocabulary(fobj, is_dict=False):
def get_vocabulary(fobj, is_dict=False, num_workers=1):
"""Read text and return dictionary that encodes vocabulary
"""
vocab = Counter()
for i, line in enumerate(fobj):
if is_dict:
if is_dict:
for i, line in enumerate(fobj):
try:
word, count = line.strip('\r\n ').split(' ')
except:
print('Failed reading vocabulary file at line {0}: {1}'.format(i, line))
sys.exit(1)
vocab[word] += int(count)
else:
elif num_workers == 1 or fobj.name == '<stdin>':
if num_workers > 1:
warnings.warn("In parallel mode, the input cannot be STDIN. Using 1 processor instead.")
for i, line in enumerate(fobj):
for word in line.strip('\r\n ').split(' '):
if word:
vocab[word] += 1
elif num_workers > 1:
if sys.version_info < (3, 0):
print("Parallel mode is only supported in Python3.")
sys.exit(1)
with open(fobj.name, encoding="utf8") as f:
size = os.fstat(f.fileno()).st_size
chunk_size = int(size / num_workers)
offsets = [0 for _ in range(num_workers + 1)]
for i in range(1, num_workers):
f.seek(chunk_size * i)
pos = f.tell()
while True:
try:
line = f.readline()
break
except UnicodeDecodeError:
pos -= 1
f.seek(pos)
offsets[i] = f.tell()
assert 0 <= offsets[i] < 1e20, "Bad new line separator, e.g. '\\r'"
vocab_files = []
pool = Pool(processes=num_workers)
for i in range(num_workers):
tmp = tempfile.NamedTemporaryFile(delete=False)
tmp.close()
vocab_files.append(tmp)
pool.apply_async(_get_vocabulary, (fobj.name, tmp.name, offsets[i], offsets[i + 1]))
pool.close()
pool.join()
import pickle
for i in range(num_workers):
with open(vocab_files[i].name, 'rb') as f:
vocab += pickle.load(f)
os.remove(vocab_files[i].name)
else:
raise ValueError('`num_workers` is expected to be a positive number, but got {}.'.format(num_workers))
return vocab
def _get_vocabulary(infile, outfile, begin, end):
import pickle
vocab = Counter()
with open(infile, encoding="utf8") as f:
f.seek(begin)
line = f.readline()
while line:
assert 0 <= f.tell() < 1e20, "Bad new line separator, e.g. '\\r'"
if end > 0 and f.tell() > end:
break
for word in line.strip('\r\n ').split(' '):
if word:
vocab[word] += 1
line = f.readline()
with open(outfile, 'wb') as f:
pickle.dump(vocab, f)
def update_pair_statistics(pair, changed, stats, indices):
"""Minimally update the indices and frequency of symbol pairs
@ -200,7 +264,7 @@ def prune_stats(stats, big_stats, threshold):
big_stats[item] = freq
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False):
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False, num_workers=1):
"""Learn num_symbols BPE operations from vocabulary, and write to outfile.
"""
@ -208,7 +272,7 @@ def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_d
# version numbering allows bckward compatibility
outfile.write('#version: 0.2\n')
vocab = get_vocabulary(infile, is_dict)
vocab = get_vocabulary(infile, is_dict, num_workers)
vocab = dict([(tuple(x[:-1])+(x[-1]+'</w>',) ,y) for (x,y) in vocab.items()])
sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)
@ -280,10 +344,17 @@ if __name__ == '__main__':
parser = create_parser()
args = parser.parse_args()
if args.num_workers <= 0:
args.num_workers = cpu_count()
if sys.version_info < (3, 0) and args.num_workers > 1:
args.num_workers = 1
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
# read/write files as UTF-8
if args.input.name != '<stdin>':
args.input = codecs.open(args.input.name, encoding='utf-8')
if args.output.name != '<stdout>':
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols)
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols, num_workers=args.num_workers)

View File

@ -22,6 +22,7 @@ import argparse
import tempfile
import warnings
from collections import Counter
from multiprocessing import cpu_count
#hack to get imports working if running this as a script, or within a package
if __name__ == '__main__':
@ -56,20 +57,23 @@ def create_parser(subparsers=None):
help="Output file for BPE codes.")
parser.add_argument(
'--symbols', '-s', type=int, default=10000,
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))")
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")
parser.add_argument(
'--separator', type=str, default='@@', metavar='STR',
help="Separator between non-final subword units (default: '%(default)s'))")
help="Separator between non-final subword units (default: '%(default)s')")
parser.add_argument(
'--write-vocabulary', type=argparse.FileType('w'), required=True, nargs = '+', default=None,
metavar='PATH', dest='vocab',
help='Write to these vocabulary files after applying BPE. One per input text. Used for filtering in apply_bpe.py')
parser.add_argument(
'--min-frequency', type=int, default=2, metavar='FREQ',
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))')
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
parser.add_argument(
'--total-symbols', '-t', action="store_true",
help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
parser.add_argument(
'--num-workers', type=int, default=1,
help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
parser.add_argument(
'--verbose', '-v', action="store_true",
help="verbose mode.")
@ -89,7 +93,7 @@ def learn_joint_bpe_and_vocab(args):
# get combined vocabulary of all input texts
full_vocab = Counter()
for f in args.input:
full_vocab += learn_bpe.get_vocabulary(f)
full_vocab += learn_bpe.get_vocabulary(f, num_workers=args.num_workers)
f.seek(0)
vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]
@ -110,14 +114,12 @@ def learn_joint_bpe_and_vocab(args):
tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')
train_file.seek(0)
for line in train_file:
tmpout.write(bpe.segment(line).strip())
tmpout.write('\n')
bpe.process_lines(train_file.name, tmpout, num_workers=args.num_workers)
tmpout.close()
tmpin = codecs.open(tmp.name, encoding='UTF-8')
vocab = learn_bpe.get_vocabulary(tmpin)
vocab = learn_bpe.get_vocabulary(tmpin, num_workers=args.num_workers)
tmpin.close()
os.remove(tmp.name)
@ -150,8 +152,14 @@ if __name__ == '__main__':
parser = create_parser()
args = parser.parse_args()
if args.num_workers <= 0:
args.num_workers = cpu_count()
if sys.version_info < (3, 0):
args.separator = args.separator.decode('UTF-8')
if args.num_workers > 1:
args.num_workers = 1
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
assert(len(args.input) == len(args.vocab))