Move string line encoding logic from tokenizer to Dictionary (unified diff). (#541)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/541

This is just a combination of the stacked pair D14057943 & D14176011.
It was made as a separate diff because there seems to be an issue with porting a stacked change into the GitHub repo.
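
In practice this means call sites that went through the static Tokenizer.tokenize helper now call encode_line on the dictionary itself. A minimal before/after sketch (the dictionary and sentence below are illustrative, assuming fairseq at this revision is importable):

    from fairseq.data import Dictionary

    d = Dictionary()
    sentence = 'hello world'

    # Old API, removed by this change:
    #   from fairseq.tokenizer import Tokenizer
    #   tokens = Tokenizer.tokenize(sentence, d, add_if_not_exist=True)

    # New API: the dictionary owns the string-to-ids encoding.
    tokens = d.encode_line(sentence, add_if_not_exist=True)  # torch.IntTensor, eos-terminated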

Differential Revision: D14251048

fbshipit-source-id: 0a47f534a69d6ab2ebe035fba40fd51748cccfb8
Vladimir Karpukhin 2019-02-28 09:15:35 -08:00 committed by Facebook Github Bot
parent bc919276a1
commit f296824f40
13 changed files with 204 additions and 196 deletions


@@ -209,7 +209,6 @@ following contents::
from fairseq.data import Dictionary, LanguagePairDataset
from fairseq.tasks import FairseqTask, register_task
from fairseq.tokenizer import Tokenizer
@register_task('simple_classification')
@@ -253,8 +252,8 @@ following contents::
sentence = line.strip()
# Tokenize the sentence, splitting on spaces
tokens = Tokenizer.tokenize(
sentence, self.input_vocab, add_if_not_exist=False,
tokens = self.input_vocab.encode_line(
sentence, add_if_not_exist=False,
)
sentences.append(tokens)
@@ -356,7 +355,6 @@ Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents::
from fairseq import data, options, tasks, utils
from fairseq.tokenizer import Tokenizer
# Parse command-line arguments for generation
parser = options.get_generation_parser(default_task='simple_classification')
@@ -375,8 +373,8 @@ a new file named :file:`eval_classifier.py` with the following contents::
# Tokenize into characters
chars = ' '.join(list(sentence.strip()))
tokens = Tokenizer.tokenize(
chars, task.source_dictionary, add_if_not_exist=False,
tokens = task.source_dictionary.encode_line(
chars, add_if_not_exist=False,
)
# Build mini-batch to feed to the model

fairseq/binarizer.py (new file)

@@ -0,0 +1,67 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import Counter
import os
from fairseq.tokenizer import tokenize_line
def safe_readline(f):
pos = f.tell()
while True:
try:
return f.readline()
except UnicodeDecodeError:
pos -= 1
f.seek(pos) # search where this character begins
class Binarizer:
@staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line, append_eos=True, reverse_order=False,
offset=0, end=-1):
nseq, ntok = 0, 0
replaced = Counter()
def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word:
replaced.update([word])
with open(filename, 'r', encoding='utf-8') as f:
f.seek(offset)
# next(f) breaks f.tell(), hence readline() must be used
line = safe_readline(f)
while line:
if end > 0 and f.tell() > end:
break
ids = dict.encode_line(
line=line,
line_tokenizer=tokenize,
add_if_not_exist=False,
consumer=replaced_consumer,
append_eos=append_eos,
reverse_order=reverse_order,
)
nseq += 1
ntok += len(ids)
consumer(ids)
line = f.readline()
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
@staticmethod
def find_offsets(filename, num_chunks):
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_chunks
offsets = [0 for _ in range(num_chunks + 1)]
for i in range(1, num_chunks):
f.seek(chunk_size * i)
safe_readline(f)
offsets[i] = f.tell()
return offsets
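
Roughly how preprocess.py (further down in this diff) drives the new class: find_offsets splits the input file into byte ranges and each worker binarizes only its own chunk. A hedged sketch; the file names and the list-based consumer are placeholders:

    from fairseq.data import Dictionary
    from fairseq.binarizer import Binarizer

    vocab = Dictionary.load('dict.en.txt')                       # placeholder dictionary file
    offsets = Binarizer.find_offsets('train.en', num_chunks=2)   # placeholder input file

    tensors = []
    stats = Binarizer.binarize(
        'train.en', vocab, lambda t: tensors.append(t),
        offset=offsets[0], end=offsets[1],                       # first chunk only
    )
    print(stats['nseq'], stats['ntok'], stats['nunk'])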


@@ -213,3 +213,11 @@ def batch_by_size(
if len(batch) > 0:
yield batch
def process_bpe_symbol(sentence: str, bpe_symbol: str):
if bpe_symbol == 'sentencepiece':
sentence = sentence.replace('\u2581', ' ').strip()
elif bpe_symbol is not None:
sentence = (sentence + ' ').replace(bpe_symbol, '').rstrip()
return sentence
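
The helper centralizes the stripping that Dictionary.string previously did inline (it is imported below as fairseq.data.data_utils). A small illustration of both branches; the sample strings are made up:

    from fairseq.data import data_utils

    # subword-nmt style: every occurrence of bpe_symbol is deleted
    data_utils.process_bpe_symbol('new@@ est rate@@ s', '@@ ')                       # 'newest rates'
    # sentencepiece: U+2581 word-boundary markers become spaces
    data_utils.process_bpe_symbol('\u2581the\u2581new\u2581rates', 'sentencepiece')  # 'the new rates'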


@@ -6,10 +6,15 @@
# can be found in the PATENTS file in the same directory.
from collections import Counter
from multiprocessing import Pool
import os
import torch
from fairseq.tokenizer import tokenize_line
from fairseq.binarizer import safe_readline
from fairseq.data import data_utils
class Dictionary(object):
"""A mapping from symbols to consecutive integers"""
@@ -57,14 +62,8 @@ class Dictionary(object):
else:
return self[i]
if bpe_symbol == 'sentencepiece':
sent = ''.join(token_string(i) for i in tensor if i != self.eos())
sent = sent.replace('\u2581', ' ').strip()
else:
sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
if bpe_symbol is not None and bpe_symbol != 'sentencepiece':
sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
return sent
sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
return data_utils.process_bpe_symbol(sent, bpe_symbol)
def unk_string(self, escape=False):
"""Return unknown string, optionally escaped as: <<unk>>"""
@@ -181,31 +180,104 @@ class Dictionary(object):
"rebuild the dataset".format(f))
d = cls()
for line in f.readlines():
lines = f.readlines()
indices_start_line = d._load_meta(lines)
for line in lines[indices_start_line:]:
idx = line.rfind(' ')
if idx == -1:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
word = line[:idx]
count = int(line[idx+1:])
count = int(line[idx + 1:])
d.indices[word] = len(d.symbols)
d.symbols.append(word)
d.count.append(count)
return d
def save(self, f):
"""Stores dictionary into a text file"""
def _save(self, f, kv_iterator):
if isinstance(f, str):
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd)
for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
print('{} {}'.format(symbol, count), file=f)
for k, v in kv_iterator:
print('{} {}'.format(k, v), file=f)
def _get_meta(self):
return [], []
def _load_meta(self, lines):
return 0
def save(self, f):
"""Stores dictionary into a text file"""
ex_keys, ex_vals = self._get_meta()
self._save(f, zip(ex_keys + self.symbols[self.nspecial:], ex_vals + self.count[self.nspecial:]))
def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t
def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False):
words = line_tokenizer(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = self.add_symbol(word)
else:
idx = self.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = self.eos_index
return ids
@staticmethod
def _add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
counter = Counter()
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_workers
offset = worker_id * chunk_size
end = offset + chunk_size
f.seek(offset)
if offset > 0:
safe_readline(f) # drop first incomplete line
line = f.readline()
while line:
for word in tokenize(line):
counter.update([word])
counter.update([eos_word])
if f.tell() > end:
break
line = f.readline()
return counter
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
def merge_result(counter):
for w, c in counter.items():
dict.add_symbol(w, c)
if num_workers > 1:
pool = Pool(processes=num_workers)
results = []
for worker_id in range(num_workers):
results.append(pool.apply_async(
Dictionary._add_file_to_dictionary_single_worker,
(filename, tokenize, dict.eos_word, worker_id, num_workers)
))
pool.close()
pool.join()
for r in results:
merge_result(r.get())
else:
merge_result(Dictionary._add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
class TruncatedDictionary(object):
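
With encode_line and string both living on the dictionary, encoding and decoding become a round trip through one object. A hypothetical sketch, assuming fairseq at this revision is importable:

    import torch
    from fairseq.data import Dictionary

    d = Dictionary()
    ids = d.encode_line('hello world hello', add_if_not_exist=True)  # grows the vocab on the fly
    assert ids.dtype == torch.int32 and ids[-1] == d.eos()           # eos is appended by default
    print(d.string(ids))  # 'hello world hello'; pass bpe_symbol to also strip BPE markers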


@@ -11,8 +11,6 @@ import struct
import numpy as np
import torch
from fairseq.tokenizer import Tokenizer
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
@@ -171,8 +169,8 @@ class IndexedRawTextDataset(torch.utils.data.Dataset):
with open(path, 'r', encoding='utf-8') as f:
for line in f:
self.lines.append(line.strip('\n'))
tokens = Tokenizer.tokenize(
line, dictionary, add_if_not_exist=False,
tokens = dictionary.encode_line(
line, add_if_not_exist=False,
append_eos=self.append_eos, reverse_order=self.reverse_order,
).long()
self.tokens_list.append(tokens)


@@ -9,7 +9,6 @@ import torch
from fairseq import tokenizer
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
from fairseq.tokenizer import Tokenizer
class FairseqTask(object):
@@ -52,7 +51,7 @@ class FairseqTask(object):
"""
d = Dictionary()
for filename in filenames:
Tokenizer.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
return d


@@ -5,13 +5,8 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import Counter
from multiprocessing import Pool
import os
import re
import torch
SPACE_NORMALIZER = re.compile(r"\s+")
@@ -19,124 +14,3 @@ def tokenize_line(line):
line = SPACE_NORMALIZER.sub(" ", line)
line = line.strip()
return line.split()
def safe_readline(f):
pos = f.tell()
while True:
try:
return f.readline()
except UnicodeDecodeError:
pos -= 1
f.seek(pos) # search where this character begins
class Tokenizer:
@staticmethod
def add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
counter = Counter()
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_workers
offset = worker_id * chunk_size
end = offset + chunk_size
f.seek(offset)
if offset > 0:
safe_readline(f) # drop first incomplete line
line = f.readline()
while line:
for word in tokenize(line):
counter.update([word])
counter.update([eos_word])
if f.tell() > end:
break
line = f.readline()
return counter
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
def merge_result(counter):
for w, c in counter.items():
dict.add_symbol(w, c)
if num_workers > 1:
pool = Pool(processes=num_workers)
results = []
for worker_id in range(num_workers):
results.append(pool.apply_async(
Tokenizer.add_file_to_dictionary_single_worker,
(filename, tokenize, dict.eos_word, worker_id, num_workers)
))
pool.close()
pool.join()
for r in results:
merge_result(r.get())
else:
merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
@staticmethod
def binarize(
filename, dict, consumer, tokenize=tokenize_line, append_eos=True,
reverse_order=False, offset=0, end=-1,
):
nseq, ntok = 0, 0
replaced = Counter()
def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word:
replaced.update([word])
with open(filename, 'r', encoding='utf-8') as f:
f.seek(offset)
# next(f) breaks f.tell(), hence readline() must be used
line = safe_readline(f)
while line:
if end > 0 and f.tell() > end:
break
ids = Tokenizer.tokenize(
line=line,
dict=dict,
tokenize=tokenize,
add_if_not_exist=False,
consumer=replaced_consumer,
append_eos=append_eos,
reverse_order=reverse_order,
)
nseq += 1
ntok += len(ids)
consumer(ids)
line = f.readline()
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
@staticmethod
def find_offsets(filename, num_chunks):
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_chunks
offsets = [0 for _ in range(num_chunks + 1)]
for i in range(1, num_chunks):
f.seek(chunk_size * i)
safe_readline(f)
offsets[i] = f.tell()
return offsets
@staticmethod
def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False):
words = tokenize(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = dict.add_symbol(word)
else:
idx = dict.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = dict.eos_index
return ids


@@ -304,7 +304,7 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dic
if align_dict is not None or remove_bpe is not None:
# Convert back to tokens for evaluating with unk replacement or without BPE
# Note that the dictionary can be modified inside the method.
hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)
return hypo_tokens, hypo_str, alignment


@@ -165,8 +165,7 @@ def main(args):
if has_target and i == 0:
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(
target_str, tgt_dict, add_if_not_exist=True)
target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
if hasattr(scorer, 'add_string'):
scorer.add_string(target_str, hypo_str)
else:


@@ -38,7 +38,7 @@ def buffered_read(input, buffer_size):
def make_batches(lines, args, task, max_positions):
tokens = [
tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long()
task.source_dictionary.encode_line(src_str, add_if_not_exist=False).long()
for src_str in lines
]
lengths = torch.LongTensor([t.numel() for t in tokens])


@@ -11,15 +11,15 @@ Data pre-processing: build vocabularies and binarize training data.
from collections import Counter
from itertools import zip_longest
import os
import shutil
from fairseq import options, tasks
from fairseq.data import indexed_dataset
from fairseq.tokenizer import Tokenizer
from fairseq.binarizer import Binarizer
from fairseq.utils import import_user_module
from multiprocessing import Pool
from fairseq.utils import import_user_module
import os
import shutil
def main(args):
@@ -95,9 +95,8 @@ def main(args):
if target and tgt_dict is not None:
tgt_dict.save(dict_path(args.target_lang))
def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
dict = task.load_dictionary(dict_path(lang))
print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
n_seq_tok = [0, 0]
replaced = Counter()
@@ -109,7 +108,7 @@ def main(args):
input_file = "{}{}".format(
input_prefix, ("." + lang) if lang is not None else ""
)
offsets = Tokenizer.find_offsets(input_file, num_workers)
offsets = Binarizer.find_offsets(input_file, num_workers)
pool = None
if num_workers > 1:
pool = Pool(processes=num_workers - 1)
@@ -120,13 +119,13 @@
(
args,
input_file,
dict,
vocab,
prefix,
lang,
offsets[worker_id],
offsets[worker_id + 1],
offsets[worker_id + 1]
),
callback=merge_result,
callback=merge_result
)
pool.close()
@@ -134,8 +133,9 @@ def main(args):
dataset_dest_file(args, output_prefix, lang, "bin")
)
merge_result(
Tokenizer.binarize(
input_file, dict, lambda t: ds.add_item(t), offset=0, end=offsets[1]
Binarizer.binarize(
input_file, vocab, lambda t: ds.add_item(t),
offset=0, end=offsets[1]
)
)
if num_workers > 1:
@@ -156,13 +156,13 @@ def main(args):
n_seq_tok[0],
n_seq_tok[1],
100 * sum(replaced.values()) / n_seq_tok[1],
dict.unk_word,
vocab.unk_word,
)
)
def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
if args.output_format == "binary":
make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)
elif args.output_format == "raw":
# Copy original text file to destination folder
output_text_file = dest_path(
@@ -171,21 +171,21 @@ def main(args):
)
shutil.copyfile(file_name(input_prefix, lang), output_text_file)
def make_all(lang):
def make_all(lang, vocab):
if args.trainpref:
make_dataset(args.trainpref, "train", lang, num_workers=args.workers)
make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)
if args.validpref:
for k, validpref in enumerate(args.validpref.split(",")):
outprefix = "valid{}".format(k) if k > 0 else "valid"
make_dataset(validpref, outprefix, lang)
make_dataset(vocab, validpref, outprefix, lang)
if args.testpref:
for k, testpref in enumerate(args.testpref.split(",")):
outprefix = "test{}".format(k) if k > 0 else "test"
make_dataset(testpref, outprefix, lang)
make_dataset(vocab, testpref, outprefix, lang)
make_all(args.source_lang)
make_all(args.source_lang, src_dict)
if target:
make_all(args.target_lang)
make_all(args.target_lang, tgt_dict)
print("| Wrote preprocessed data to {}".format(args.destdir))
@@ -198,8 +198,8 @@ def main(args):
with open(src_file_name, "r", encoding='utf-8') as src_file:
with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
for a, s, t in zip_longest(align_file, src_file, tgt_file):
si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
si = src_dict.encode_line(s, add_if_not_exist=False)
ti = tgt_dict.encode_line(t, add_if_not_exist=False)
ai = list(map(lambda x: tuple(x.split("-")), a.split()))
for sai, tai in ai:
srcidx = si[int(sai)]
@@ -232,7 +232,7 @@ def main(args):
print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def binarize(args, filename, dict, output_prefix, lang, offset, end, append_eos=True):
def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
ds = indexed_dataset.IndexedDatasetBuilder(
dataset_dest_file(args, output_prefix, lang, "bin")
)
@@ -240,14 +240,8 @@ def binarize(args, filename, dict, output_prefix, lang, offset, end, append_eos=
def consumer(tensor):
ds.add_item(tensor)
res = Tokenizer.binarize(
filename,
dict,
consumer,
offset=offset,
end=end,
append_eos=append_eos
)
res = Binarizer.binarize(filename, vocab, consumer, append_eos=append_eos,
offset=offset, end=end)
ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
return res
@@ -266,7 +260,7 @@ def dataset_dest_file(args, output_prefix, lang, extension):
def get_offsets(input_file, num_workers):
return Tokenizer.find_offsets(input_file, num_workers)
return Binarizer.find_offsets(input_file, num_workers)
def merge_files(files, outpath):


@@ -13,7 +13,7 @@ import argparse
import os
import sys
from fairseq import bleu, tokenizer
from fairseq import bleu
from fairseq.data import dictionary
@@ -62,8 +62,8 @@ def main():
with open(args.ref) as fdref:
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
sys_tok = dict.encode_line(sys_tok)
ref_tok = dict.encode_line(ref_tok)
scorer.add(ref_tok, sys_tok)
print(scorer.result_string(args.order))


@@ -11,7 +11,6 @@ import unittest
import torch
from fairseq.data import Dictionary
from fairseq.tokenizer import Tokenizer
class TestDictionary(unittest.TestCase):
@@ -39,12 +38,12 @@ class TestDictionary(unittest.TestCase):
# build dictionary
d = Dictionary()
for line in txt:
Tokenizer.tokenize(line, d, add_if_not_exist=True)
d.encode_line(line, add_if_not_exist=True)
def get_ids(dictionary):
ids = []
for line in txt:
ids.append(Tokenizer.tokenize(line, dictionary, add_if_not_exist=False))
ids.append(dictionary.encode_line(line, add_if_not_exist=False))
return ids
def assertMatch(ids, ref_ids):