mirror of
https://github.com/rsennrich/subword-nmt.git
synced 2024-11-30 14:22:00 +03:00
Add parallel support (--num-workers)
This commit is contained in:
parent
ef69c1eb27
commit
bce5f8abcf
@ -22,7 +22,8 @@ import argparse
|
|||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
import random
|
import random
|
||||||
|
import tempfile
|
||||||
|
from multiprocessing import Pool, cpu_count
|
||||||
|
|
||||||
# hack for python2/3 compatibility
|
# hack for python2/3 compatibility
|
||||||
from io import open
|
from io import open
|
||||||
@ -67,6 +68,48 @@ class BPE(object):
|
|||||||
|
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
|
def process_lines(self, filename, outfile, dropout=0, num_workers=1):
|
||||||
|
|
||||||
|
if sys.version_info < (3, 0):
|
||||||
|
print("Parallel mode is only supported in Python3.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if num_workers == 1:
|
||||||
|
_process_lines(self, filename, outfile, dropout, 0, 0)
|
||||||
|
elif num_workers > 1:
|
||||||
|
with open(filename, encoding="utf-8") as f:
|
||||||
|
size = os.fstat(f.fileno()).st_size
|
||||||
|
chunk_size = int(size / num_workers)
|
||||||
|
offsets = [0 for _ in range(num_workers + 1)]
|
||||||
|
for i in range(1, num_workers):
|
||||||
|
f.seek(chunk_size * i)
|
||||||
|
pos = f.tell()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
line = f.readline()
|
||||||
|
break
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
pos -= 1
|
||||||
|
f.seek(pos)
|
||||||
|
offsets[i] = f.tell()
|
||||||
|
assert 0 <= offsets[i] < 1e20, "Bad new line separator, e.g. '\\r'"
|
||||||
|
res_files = []
|
||||||
|
pool = Pool(processes=num_workers)
|
||||||
|
for i in range(num_workers):
|
||||||
|
tmp = tempfile.NamedTemporaryFile(delete=False)
|
||||||
|
tmp.close()
|
||||||
|
res_files.append(tmp)
|
||||||
|
pool.apply_async(_process_lines, (self, filename, tmp.name, dropout, offsets[i], offsets[i + 1]))
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
for i in range(num_workers):
|
||||||
|
with open(res_files[i].name, encoding="utf-8") as fi:
|
||||||
|
for line in fi:
|
||||||
|
outfile.write(line)
|
||||||
|
os.remove(res_files[i].name)
|
||||||
|
else:
|
||||||
|
raise ValueError('`num_workers` is expected to be a positive number, but got {}.'.format(num_workers))
|
||||||
|
|
||||||
def process_line(self, line, dropout=0):
|
def process_line(self, line, dropout=0):
|
||||||
"""segment line, dealing with leading and trailing whitespace"""
|
"""segment line, dealing with leading and trailing whitespace"""
|
||||||
|
|
||||||
@ -120,6 +163,23 @@ class BPE(object):
|
|||||||
for out_segments in isolate_glossary(segment, gloss)]
|
for out_segments in isolate_glossary(segment, gloss)]
|
||||||
return word_segments
|
return word_segments
|
||||||
|
|
||||||
|
def _process_lines(bpe, filename, outfile, dropout, begin, end):
|
||||||
|
if isinstance(outfile, str):
|
||||||
|
fo = open(outfile, "w", encoding="utf-8")
|
||||||
|
else:
|
||||||
|
fo = outfile
|
||||||
|
with open(filename, encoding="utf-8") as f:
|
||||||
|
f.seek(begin)
|
||||||
|
line = f.readline()
|
||||||
|
while line:
|
||||||
|
assert 0 <= f.tell() < 1e20, "Bad new line separator, e.g. '\\r'"
|
||||||
|
if end > 0 and f.tell() > end:
|
||||||
|
break
|
||||||
|
fo.write(bpe.process_line(line, dropout))
|
||||||
|
line = f.readline()
|
||||||
|
if isinstance(outfile, str):
|
||||||
|
fo.close()
|
||||||
|
|
||||||
def create_parser(subparsers=None):
|
def create_parser(subparsers=None):
|
||||||
|
|
||||||
if subparsers:
|
if subparsers:
|
||||||
@ -173,6 +233,9 @@ def create_parser(subparsers=None):
|
|||||||
'--seed', type=int, default=None,
|
'--seed', type=int, default=None,
|
||||||
metavar="S",
|
metavar="S",
|
||||||
help="Random seed for the random number generators (e.g. for BPE dropout with --dropout).")
|
help="Random seed for the random number generators (e.g. for BPE dropout with --dropout).")
|
||||||
|
parser.add_argument(
|
||||||
|
'--num-workers', type=int, default=1,
|
||||||
|
help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -345,6 +408,9 @@ if __name__ == '__main__':
|
|||||||
parser = create_parser()
|
parser = create_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.num_workers <= 0:
|
||||||
|
args.num_workers = cpu_count()
|
||||||
|
|
||||||
# read/write files as UTF-8
|
# read/write files as UTF-8
|
||||||
args.codes = codecs.open(args.codes.name, encoding='utf-8')
|
args.codes = codecs.open(args.codes.name, encoding='utf-8')
|
||||||
if args.input.name != '<stdin>':
|
if args.input.name != '<stdin>':
|
||||||
@ -363,11 +429,19 @@ if __name__ == '__main__':
|
|||||||
args.separator = args.separator.decode('UTF-8')
|
args.separator = args.separator.decode('UTF-8')
|
||||||
if args.glossaries:
|
if args.glossaries:
|
||||||
args.glossaries = [g.decode('UTF-8') for g in args.glossaries]
|
args.glossaries = [g.decode('UTF-8') for g in args.glossaries]
|
||||||
|
if args.num_workers > 1:
|
||||||
|
args.num_workers = 1
|
||||||
|
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
|
||||||
|
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
|
|
||||||
bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries)
|
bpe = BPE(args.codes, args.merges, args.separator, vocabulary, args.glossaries)
|
||||||
|
|
||||||
|
if args.input.name == '<stdin>' or args.num_workers == 1:
|
||||||
|
if args.num_workers > 1:
|
||||||
|
warnings.warn("In parallel mode, the input cannot be STDIN. Using 1 processor instead.")
|
||||||
for line in args.input:
|
for line in args.input:
|
||||||
args.output.write(bpe.process_line(line, args.dropout))
|
args.output.write(bpe.process_line(line, args.dropout))
|
||||||
|
else:
|
||||||
|
bpe.process_lines(args.input.name, args.output, args.dropout, args.num_workers)
|
||||||
|
@ -21,6 +21,8 @@ import re
|
|||||||
import copy
|
import copy
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
import warnings
|
||||||
|
import tempfile
|
||||||
|
from multiprocessing import Pool, cpu_count
|
||||||
from collections import defaultdict, Counter
|
from collections import defaultdict, Counter
|
||||||
|
|
||||||
# hack for python2/3 compatibility
|
# hack for python2/3 compatibility
|
||||||
@ -49,39 +51,101 @@ def create_parser(subparsers=None):
|
|||||||
help="Output file for BPE codes (default: standard output)")
|
help="Output file for BPE codes (default: standard output)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--symbols', '-s', type=int, default=10000,
|
'--symbols', '-s', type=int, default=10000,
|
||||||
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))")
|
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--min-frequency', type=int, default=2, metavar='FREQ',
|
'--min-frequency', type=int, default=2, metavar='FREQ',
|
||||||
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))')
|
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
|
||||||
parser.add_argument('--dict-input', action="store_true",
|
parser.add_argument('--dict-input', action="store_true",
|
||||||
help="If set, input file is interpreted as a dictionary where each line contains a word-count pair")
|
help="If set, input file is interpreted as a dictionary where each line contains a word-count pair")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--total-symbols', '-t', action="store_true",
|
'--total-symbols', '-t', action="store_true",
|
||||||
help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
|
help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
|
||||||
|
parser.add_argument(
|
||||||
|
'--num-workers', type=int, default=1,
|
||||||
|
help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--verbose', '-v', action="store_true",
|
'--verbose', '-v', action="store_true",
|
||||||
help="verbose mode.")
|
help="verbose mode.")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def get_vocabulary(fobj, is_dict=False):
|
def get_vocabulary(fobj, is_dict=False, num_workers=1):
|
||||||
"""Read text and return dictionary that encodes vocabulary
|
"""Read text and return dictionary that encodes vocabulary
|
||||||
"""
|
"""
|
||||||
vocab = Counter()
|
vocab = Counter()
|
||||||
for i, line in enumerate(fobj):
|
|
||||||
if is_dict:
|
if is_dict:
|
||||||
|
for i, line in enumerate(fobj):
|
||||||
try:
|
try:
|
||||||
word, count = line.strip('\r\n ').split(' ')
|
word, count = line.strip('\r\n ').split(' ')
|
||||||
except:
|
except:
|
||||||
print('Failed reading vocabulary file at line {0}: {1}'.format(i, line))
|
print('Failed reading vocabulary file at line {0}: {1}'.format(i, line))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
vocab[word] += int(count)
|
vocab[word] += int(count)
|
||||||
else:
|
elif num_workers == 1 or fobj.name == '<stdin>':
|
||||||
|
if num_workers > 1:
|
||||||
|
warnings.warn("In parallel mode, the input cannot be STDIN. Using 1 processor instead.")
|
||||||
|
for i, line in enumerate(fobj):
|
||||||
for word in line.strip('\r\n ').split(' '):
|
for word in line.strip('\r\n ').split(' '):
|
||||||
if word:
|
if word:
|
||||||
vocab[word] += 1
|
vocab[word] += 1
|
||||||
|
elif num_workers > 1:
|
||||||
|
|
||||||
|
if sys.version_info < (3, 0):
|
||||||
|
print("Parallel mode is only supported in Python3.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
with open(fobj.name, encoding="utf8") as f:
|
||||||
|
size = os.fstat(f.fileno()).st_size
|
||||||
|
chunk_size = int(size / num_workers)
|
||||||
|
offsets = [0 for _ in range(num_workers + 1)]
|
||||||
|
for i in range(1, num_workers):
|
||||||
|
f.seek(chunk_size * i)
|
||||||
|
pos = f.tell()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
line = f.readline()
|
||||||
|
break
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
pos -= 1
|
||||||
|
f.seek(pos)
|
||||||
|
offsets[i] = f.tell()
|
||||||
|
assert 0 <= offsets[i] < 1e20, "Bad new line separator, e.g. '\\r'"
|
||||||
|
|
||||||
|
vocab_files = []
|
||||||
|
pool = Pool(processes=num_workers)
|
||||||
|
for i in range(num_workers):
|
||||||
|
tmp = tempfile.NamedTemporaryFile(delete=False)
|
||||||
|
tmp.close()
|
||||||
|
vocab_files.append(tmp)
|
||||||
|
pool.apply_async(_get_vocabulary, (fobj.name, tmp.name, offsets[i], offsets[i + 1]))
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
import pickle
|
||||||
|
for i in range(num_workers):
|
||||||
|
with open(vocab_files[i].name, 'rb') as f:
|
||||||
|
vocab += pickle.load(f)
|
||||||
|
os.remove(vocab_files[i].name)
|
||||||
|
else:
|
||||||
|
raise ValueError('`num_workers` is expected to be a positive number, but got {}.'.format(num_workers))
|
||||||
return vocab
|
return vocab
|
||||||
|
|
||||||
|
def _get_vocabulary(infile, outfile, begin, end):
|
||||||
|
import pickle
|
||||||
|
vocab = Counter()
|
||||||
|
with open(infile, encoding="utf8") as f:
|
||||||
|
f.seek(begin)
|
||||||
|
line = f.readline()
|
||||||
|
while line:
|
||||||
|
assert 0 <= f.tell() < 1e20, "Bad new line separator, e.g. '\\r'"
|
||||||
|
if end > 0 and f.tell() > end:
|
||||||
|
break
|
||||||
|
for word in line.strip('\r\n ').split(' '):
|
||||||
|
if word:
|
||||||
|
vocab[word] += 1
|
||||||
|
line = f.readline()
|
||||||
|
with open(outfile, 'wb') as f:
|
||||||
|
pickle.dump(vocab, f)
|
||||||
|
|
||||||
def update_pair_statistics(pair, changed, stats, indices):
|
def update_pair_statistics(pair, changed, stats, indices):
|
||||||
"""Minimally update the indices and frequency of symbol pairs
|
"""Minimally update the indices and frequency of symbol pairs
|
||||||
|
|
||||||
@ -200,7 +264,7 @@ def prune_stats(stats, big_stats, threshold):
|
|||||||
big_stats[item] = freq
|
big_stats[item] = freq
|
||||||
|
|
||||||
|
|
||||||
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False):
|
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False, num_workers=1):
|
||||||
"""Learn num_symbols BPE operations from vocabulary, and write to outfile.
|
"""Learn num_symbols BPE operations from vocabulary, and write to outfile.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -208,7 +272,7 @@ def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_d
|
|||||||
# version numbering allows bckward compatibility
|
# version numbering allows bckward compatibility
|
||||||
outfile.write('#version: 0.2\n')
|
outfile.write('#version: 0.2\n')
|
||||||
|
|
||||||
vocab = get_vocabulary(infile, is_dict)
|
vocab = get_vocabulary(infile, is_dict, num_workers)
|
||||||
vocab = dict([(tuple(x[:-1])+(x[-1]+'</w>',) ,y) for (x,y) in vocab.items()])
|
vocab = dict([(tuple(x[:-1])+(x[-1]+'</w>',) ,y) for (x,y) in vocab.items()])
|
||||||
sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)
|
sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
@ -280,10 +344,17 @@ if __name__ == '__main__':
|
|||||||
parser = create_parser()
|
parser = create_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.num_workers <= 0:
|
||||||
|
args.num_workers = cpu_count()
|
||||||
|
|
||||||
|
if sys.version_info < (3, 0) and args.num_workers > 1:
|
||||||
|
args.num_workers = 1
|
||||||
|
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
|
||||||
|
|
||||||
# read/write files as UTF-8
|
# read/write files as UTF-8
|
||||||
if args.input.name != '<stdin>':
|
if args.input.name != '<stdin>':
|
||||||
args.input = codecs.open(args.input.name, encoding='utf-8')
|
args.input = codecs.open(args.input.name, encoding='utf-8')
|
||||||
if args.output.name != '<stdout>':
|
if args.output.name != '<stdout>':
|
||||||
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
|
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
|
||||||
|
|
||||||
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols)
|
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols, num_workers=args.num_workers)
|
||||||
|
@ -22,6 +22,7 @@ import argparse
|
|||||||
import tempfile
|
import tempfile
|
||||||
import warnings
|
import warnings
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
#hack to get imports working if running this as a script, or within a package
|
#hack to get imports working if running this as a script, or within a package
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -56,20 +57,23 @@ def create_parser(subparsers=None):
|
|||||||
help="Output file for BPE codes.")
|
help="Output file for BPE codes.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--symbols', '-s', type=int, default=10000,
|
'--symbols', '-s', type=int, default=10000,
|
||||||
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))")
|
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--separator', type=str, default='@@', metavar='STR',
|
'--separator', type=str, default='@@', metavar='STR',
|
||||||
help="Separator between non-final subword units (default: '%(default)s'))")
|
help="Separator between non-final subword units (default: '%(default)s')")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--write-vocabulary', type=argparse.FileType('w'), required=True, nargs = '+', default=None,
|
'--write-vocabulary', type=argparse.FileType('w'), required=True, nargs = '+', default=None,
|
||||||
metavar='PATH', dest='vocab',
|
metavar='PATH', dest='vocab',
|
||||||
help='Write to these vocabulary files after applying BPE. One per input text. Used for filtering in apply_bpe.py')
|
help='Write to these vocabulary files after applying BPE. One per input text. Used for filtering in apply_bpe.py')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--min-frequency', type=int, default=2, metavar='FREQ',
|
'--min-frequency', type=int, default=2, metavar='FREQ',
|
||||||
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))')
|
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--total-symbols', '-t', action="store_true",
|
'--total-symbols', '-t', action="store_true",
|
||||||
help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
|
help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
|
||||||
|
parser.add_argument(
|
||||||
|
'--num-workers', type=int, default=1,
|
||||||
|
help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--verbose', '-v', action="store_true",
|
'--verbose', '-v', action="store_true",
|
||||||
help="verbose mode.")
|
help="verbose mode.")
|
||||||
@ -89,7 +93,7 @@ def learn_joint_bpe_and_vocab(args):
|
|||||||
# get combined vocabulary of all input texts
|
# get combined vocabulary of all input texts
|
||||||
full_vocab = Counter()
|
full_vocab = Counter()
|
||||||
for f in args.input:
|
for f in args.input:
|
||||||
full_vocab += learn_bpe.get_vocabulary(f)
|
full_vocab += learn_bpe.get_vocabulary(f, num_workers=args.num_workers)
|
||||||
f.seek(0)
|
f.seek(0)
|
||||||
|
|
||||||
vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]
|
vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]
|
||||||
@ -110,14 +114,12 @@ def learn_joint_bpe_and_vocab(args):
|
|||||||
tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')
|
tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')
|
||||||
|
|
||||||
train_file.seek(0)
|
train_file.seek(0)
|
||||||
for line in train_file:
|
bpe.process_lines(train_file.name, tmpout, num_workers=args.num_workers)
|
||||||
tmpout.write(bpe.segment(line).strip())
|
|
||||||
tmpout.write('\n')
|
|
||||||
|
|
||||||
tmpout.close()
|
tmpout.close()
|
||||||
tmpin = codecs.open(tmp.name, encoding='UTF-8')
|
tmpin = codecs.open(tmp.name, encoding='UTF-8')
|
||||||
|
|
||||||
vocab = learn_bpe.get_vocabulary(tmpin)
|
vocab = learn_bpe.get_vocabulary(tmpin, num_workers=args.num_workers)
|
||||||
tmpin.close()
|
tmpin.close()
|
||||||
os.remove(tmp.name)
|
os.remove(tmp.name)
|
||||||
|
|
||||||
@ -150,8 +152,14 @@ if __name__ == '__main__':
|
|||||||
parser = create_parser()
|
parser = create_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.num_workers <= 0:
|
||||||
|
args.num_workers = cpu_count()
|
||||||
|
|
||||||
if sys.version_info < (3, 0):
|
if sys.version_info < (3, 0):
|
||||||
args.separator = args.separator.decode('UTF-8')
|
args.separator = args.separator.decode('UTF-8')
|
||||||
|
if args.num_workers > 1:
|
||||||
|
args.num_workers = 1
|
||||||
|
warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")
|
||||||
|
|
||||||
assert(len(args.input) == len(args.vocab))
|
assert(len(args.input) == len(args.vocab))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user