mirror of
https://github.com/rsennrich/subword-nmt.git
synced 2024-11-27 02:53:55 +03:00
Add parallel support (--num-workers)
This commit is contained in:
parent
ef69c1eb27
commit
bce5f8abcf
@ -22,7 +22,8 @@ import argparse
|
||||
import re
|
||||
import warnings
|
||||
import random
|
||||
|
||||
import tempfile
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
# hack for python2/3 compatibility
|
||||
from io import open
|
||||
@ -67,6 +68,48 @@ class BPE(object):
|
||||
|
||||
self.cache = {}
|
||||
|
||||
def process_lines(self, filename, outfile, dropout=0, num_workers=1):
    """Apply BPE to every line of ``filename`` and write results to ``outfile``.

    filename -- path to a UTF-8 text file, one sentence per line.
    outfile -- writable file object that receives the segmented lines.
    dropout -- BPE-dropout probability, forwarded to ``process_line``.
    num_workers -- number of worker processes; 1 processes serially.
                   Values > 1 require Python 3.

    Raises ValueError if num_workers is not positive.
    """
    if num_workers == 1:
        # serial path: works on Python 2 as well, so no version check here
        _process_lines(self, filename, outfile, dropout, 0, 0)
    elif num_workers > 1:
        # multiprocessing handoff below is only supported on Python 3
        # (consistent with __main__, which downgrades to 1 worker on Python 2)
        if sys.version_info < (3, 0):
            print("Parallel mode is only supported in Python3.")
            sys.exit(1)
        with open(filename, encoding="utf-8") as f:
            size = os.fstat(f.fileno()).st_size
            chunk_size = int(size / num_workers)
            offsets = [0 for _ in range(num_workers + 1)]
            for i in range(1, num_workers):
                f.seek(chunk_size * i)
                pos = f.tell()
                # the raw seek may land inside a multi-byte UTF-8 character;
                # back up one byte at a time until a full line decodes
                while True:
                    try:
                        line = f.readline()
                        break
                    except UnicodeDecodeError:
                        pos -= 1
                        f.seek(pos)
                # chunk i starts at the first line boundary after its seek point;
                # offsets[num_workers] stays 0, which _process_lines reads as EOF
                offsets[i] = f.tell()
                assert 0 <= offsets[i] < 1e20, "Bad new line separator, e.g. '\\r'"
        res_files = []
        pool = Pool(processes=num_workers)
        for i in range(num_workers):
            # delete=False so the name stays usable after close and
            # across the worker process boundary
            tmp = tempfile.NamedTemporaryFile(delete=False)
            tmp.close()
            res_files.append(tmp)
            pool.apply_async(_process_lines, (self, filename, tmp.name, dropout, offsets[i], offsets[i + 1]))
        pool.close()
        pool.join()
        # concatenate per-worker outputs in chunk order to preserve line order
        for i in range(num_workers):
            with open(res_files[i].name, encoding="utf-8") as fi:
                for line in fi:
                    outfile.write(line)
            os.remove(res_files[i].name)
    else:
        raise ValueError('`num_workers` is expected to be a positive number, but got {}.'.format(num_workers))
|
||||
|
||||
def process_line(self, line, dropout=0):
|
||||
"""segment line, dealing with leading and trailing whitespace"""
|
||||
|
||||
@ -120,6 +163,23 @@ class BPE(object):
|
||||
for out_segments in isolate_glossary(segment, gloss)]
|
||||
return word_segments
|
||||
|
||||
def _process_lines(bpe, filename, outfile, dropout, begin, end):
|
||||
if isinstance(outfile, str):
|
||||
fo = open(outfile, "w", encoding="utf-8")
|
||||
else:
|
||||
fo = outfile
|
||||
with open(filename, encoding="utf-8") as f:
|
||||
f.seek(begin)
|
||||
line = f.readline()
|
||||
while line:
|
||||
assert 0 <= f.tell() < 1e20, "Bad new line separator, e.g. '\\r'"
|
||||
if end > 0 and f.tell() > end:
|
||||
break
|
||||
fo.write(bpe.process_line(line, dropout))
|
||||
line = f.readline()
|
||||
if isinstance(outfile, str):
|
||||
fo.close()
|
||||
|
||||
def create_parser(subparsers=None):
|
||||
|
||||
if subparsers:
|
||||
@ -173,6 +233,9 @@ def create_parser(subparsers=None):
|
||||
'--seed', type=int, default=None,
|
||||
metavar="S",
|
||||
help="Random seed for the random number generators (e.g. for BPE dropout with --dropout).")
|
||||
parser.add_argument(
|
||||
'--num-workers', type=int, default=1,
|
||||
help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
|
||||
|
||||
return parser
|
||||
|
||||
@ -345,6 +408,9 @@ if __name__ == '__main__':
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.num_workers <= 0:
|
||||
args.num_workers = cpu_count()
|
||||
|
||||
# read/write files as UTF-8
|
||||
args.codes = codecs.open(args.codes.name, encoding='utf-8')
|
||||
if args.input.name != '<stdin>':
|
||||
@ -363,11 +429,19 @@ if __name__ == '__main__':
|
||||
args.separator = args.separator.decode('UTF-8')
|
||||
if args.glossaries:
|
||||
args.glossaries = [g.decode('UTF-8') for g in args.glossaries]
|
||||
if args.num_workers > 1:
|
||||
args.num_workers = 1
|
||||
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
|
||||
|
||||
if args.seed is not None:
|
||||
random.seed(args.seed)
|
||||
|
||||
bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries)
|
||||
|
||||
for line in args.input:
|
||||
args.output.write(bpe.process_line(line, args.dropout))
|
||||
if args.input.name == '<stdin>' or args.num_workers == 1:
|
||||
if args.num_workers > 1:
|
||||
warnings.warn("In parallel mode, the input cannot be STDIN. Using 1 processor instead.")
|
||||
for line in args.input:
|
||||
args.output.write(bpe.process_line(line, args.dropout))
|
||||
else:
|
||||
bpe.process_lines(args.input.name, args.output, args.dropout, args.num_workers)
|
||||
|
@ -21,6 +21,8 @@ import re
|
||||
import copy
|
||||
import argparse
|
||||
import warnings
|
||||
import tempfile
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
# hack for python2/3 compatibility
|
||||
@ -49,39 +51,101 @@ def create_parser(subparsers=None):
|
||||
help="Output file for BPE codes (default: standard output)")
|
||||
parser.add_argument(
|
||||
'--symbols', '-s', type=int, default=10000,
|
||||
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))")
|
||||
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
'--min-frequency', type=int, default=2, metavar='FREQ',
|
||||
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))')
|
||||
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
|
||||
parser.add_argument('--dict-input', action="store_true",
|
||||
help="If set, input file is interpreted as a dictionary where each line contains a word-count pair")
|
||||
parser.add_argument(
|
||||
'--total-symbols', '-t', action="store_true",
|
||||
help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
|
||||
parser.add_argument(
|
||||
'--num-workers', type=int, default=1,
|
||||
help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
'--verbose', '-v', action="store_true",
|
||||
help="verbose mode.")
|
||||
|
||||
return parser
|
||||
|
||||
def get_vocabulary(fobj, is_dict=False, num_workers=1):
    """Read text and return dictionary that encodes vocabulary

    fobj -- open text file (or file-like object) to read.
    is_dict -- if True, each line is a "word count" pair instead of raw text.
    num_workers -- processes for raw-text counting; >1 requires a real file
                   (not STDIN) and Python 3.

    Returns a Counter mapping word -> frequency.
    Raises ValueError if num_workers is not positive.
    """
    vocab = Counter()
    if is_dict:
        for i, line in enumerate(fobj):
            try:
                word, count = line.strip('\r\n ').split(' ')
            except ValueError:
                # narrowed from a bare except: only a malformed line (not
                # exactly two space-separated fields) should abort; don't
                # swallow KeyboardInterrupt/SystemExit
                print('Failed reading vocabulary file at line {0}: {1}'.format(i, line))
                sys.exit(1)
            vocab[word] += int(count)
    elif num_workers == 1 or fobj.name == '<stdin>':
        if num_workers > 1:
            warnings.warn("In parallel mode, the input cannot be STDIN. Using 1 processor instead.")
        for i, line in enumerate(fobj):
            for word in line.strip('\r\n ').split(' '):
                if word:
                    vocab[word] += 1
    elif num_workers > 1:

        if sys.version_info < (3, 0):
            print("Parallel mode is only supported in Python3.")
            sys.exit(1)

        with open(fobj.name, encoding="utf8") as f:
            size = os.fstat(f.fileno()).st_size
            chunk_size = int(size / num_workers)
            offsets = [0 for _ in range(num_workers + 1)]
            for i in range(1, num_workers):
                f.seek(chunk_size * i)
                pos = f.tell()
                # the seek may land inside a multi-byte UTF-8 character;
                # step back until a full line decodes
                while True:
                    try:
                        line = f.readline()
                        break
                    except UnicodeDecodeError:
                        pos -= 1
                        f.seek(pos)
                # offsets[num_workers] stays 0, which the worker reads as EOF
                offsets[i] = f.tell()
                assert 0 <= offsets[i] < 1e20, "Bad new line separator, e.g. '\\r'"

        import pickle  # hoisted above the merge loop; used to combine worker Counters
        vocab_files = []
        pool = Pool(processes=num_workers)
        for i in range(num_workers):
            # delete=False keeps the name valid across the process boundary
            tmp = tempfile.NamedTemporaryFile(delete=False)
            tmp.close()
            vocab_files.append(tmp)
            pool.apply_async(_get_vocabulary, (fobj.name, tmp.name, offsets[i], offsets[i + 1]))
        pool.close()
        pool.join()
        for i in range(num_workers):
            with open(vocab_files[i].name, 'rb') as f:
                vocab += pickle.load(f)
            os.remove(vocab_files[i].name)
    else:
        raise ValueError('`num_workers` is expected to be a positive number, but got {}.'.format(num_workers))
    return vocab
|
||||
|
||||
def _get_vocabulary(infile, outfile, begin, end):
|
||||
import pickle
|
||||
vocab = Counter()
|
||||
with open(infile, encoding="utf8") as f:
|
||||
f.seek(begin)
|
||||
line = f.readline()
|
||||
while line:
|
||||
assert 0 <= f.tell() < 1e20, "Bad new line separator, e.g. '\\r'"
|
||||
if end > 0 and f.tell() > end:
|
||||
break
|
||||
for word in line.strip('\r\n ').split(' '):
|
||||
if word:
|
||||
vocab[word] += 1
|
||||
line = f.readline()
|
||||
with open(outfile, 'wb') as f:
|
||||
pickle.dump(vocab, f)
|
||||
|
||||
def update_pair_statistics(pair, changed, stats, indices):
|
||||
"""Minimally update the indices and frequency of symbol pairs
|
||||
|
||||
@ -200,7 +264,7 @@ def prune_stats(stats, big_stats, threshold):
|
||||
big_stats[item] = freq
|
||||
|
||||
|
||||
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False):
|
||||
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False, num_workers=1):
|
||||
"""Learn num_symbols BPE operations from vocabulary, and write to outfile.
|
||||
"""
|
||||
|
||||
@ -208,7 +272,7 @@ def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_d
|
||||
# version numbering allows backward compatibility
|
||||
outfile.write('#version: 0.2\n')
|
||||
|
||||
vocab = get_vocabulary(infile, is_dict)
|
||||
vocab = get_vocabulary(infile, is_dict, num_workers)
|
||||
vocab = dict([(tuple(x[:-1])+(x[-1]+'</w>',) ,y) for (x,y) in vocab.items()])
|
||||
sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
@ -280,10 +344,17 @@ if __name__ == '__main__':
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.num_workers <= 0:
|
||||
args.num_workers = cpu_count()
|
||||
|
||||
if sys.version_info < (3, 0) and args.num_workers > 1:
|
||||
args.num_workers = 1
|
||||
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
|
||||
|
||||
# read/write files as UTF-8
|
||||
if args.input.name != '<stdin>':
|
||||
args.input = codecs.open(args.input.name, encoding='utf-8')
|
||||
if args.output.name != '<stdout>':
|
||||
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
|
||||
|
||||
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols)
|
||||
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols, num_workers=args.num_workers)
|
||||
|
@ -22,6 +22,7 @@ import argparse
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
#hack to get imports working if running this as a script, or within a package
|
||||
if __name__ == '__main__':
|
||||
@ -56,20 +57,23 @@ def create_parser(subparsers=None):
|
||||
help="Output file for BPE codes.")
|
||||
parser.add_argument(
|
||||
'--symbols', '-s', type=int, default=10000,
|
||||
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))")
|
||||
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
'--separator', type=str, default='@@', metavar='STR',
|
||||
help="Separator between non-final subword units (default: '%(default)s'))")
|
||||
help="Separator between non-final subword units (default: '%(default)s')")
|
||||
parser.add_argument(
|
||||
'--write-vocabulary', type=argparse.FileType('w'), required=True, nargs = '+', default=None,
|
||||
metavar='PATH', dest='vocab',
|
||||
help='Write to these vocabulary files after applying BPE. One per input text. Used for filtering in apply_bpe.py')
|
||||
parser.add_argument(
|
||||
'--min-frequency', type=int, default=2, metavar='FREQ',
|
||||
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))')
|
||||
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
|
||||
parser.add_argument(
|
||||
'--total-symbols', '-t', action="store_true",
|
||||
help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
|
||||
parser.add_argument(
|
||||
'--num-workers', type=int, default=1,
|
||||
help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
'--verbose', '-v', action="store_true",
|
||||
help="verbose mode.")
|
||||
@ -89,7 +93,7 @@ def learn_joint_bpe_and_vocab(args):
|
||||
# get combined vocabulary of all input texts
|
||||
full_vocab = Counter()
|
||||
for f in args.input:
|
||||
full_vocab += learn_bpe.get_vocabulary(f)
|
||||
full_vocab += learn_bpe.get_vocabulary(f, num_workers=args.num_workers)
|
||||
f.seek(0)
|
||||
|
||||
vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]
|
||||
@ -110,14 +114,12 @@ def learn_joint_bpe_and_vocab(args):
|
||||
tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')
|
||||
|
||||
train_file.seek(0)
|
||||
for line in train_file:
|
||||
tmpout.write(bpe.segment(line).strip())
|
||||
tmpout.write('\n')
|
||||
bpe.process_lines(train_file.name, tmpout, num_workers=args.num_workers)
|
||||
|
||||
tmpout.close()
|
||||
tmpin = codecs.open(tmp.name, encoding='UTF-8')
|
||||
|
||||
vocab = learn_bpe.get_vocabulary(tmpin)
|
||||
vocab = learn_bpe.get_vocabulary(tmpin, num_workers=args.num_workers)
|
||||
tmpin.close()
|
||||
os.remove(tmp.name)
|
||||
|
||||
@ -150,8 +152,14 @@ if __name__ == '__main__':
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.num_workers <= 0:
|
||||
args.num_workers = cpu_count()
|
||||
|
||||
if sys.version_info < (3, 0):
|
||||
args.separator = args.separator.decode('UTF-8')
|
||||
if args.num_workers > 1:
|
||||
args.num_workers = 1
|
||||
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
|
||||
|
||||
assert(len(args.input) == len(args.vocab))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user