Memory-Mapped IndexedDataset implementation (#589)

Summary:
Following discussion in https://github.com/pytorch/fairseq/issues/574:

- Implemented MMapIndexedDataset and MMapIndexedDatasetBuilder, compatible with IndexedDataset/IndexedDatasetBuilder
- Updated scripts/read_binarized.py to support the new MMapIndexedDataset
- Replaced the '--raw-text' and '--lazy-load' options with '--dataset-impl', and moved the option definition from the custom per-task args to the more general options.add_dataset_args() (a more appropriate place)
- Also implemented utility functions in indexed_dataset: make_builder(), make_dataset() and dataset_exists() (see the usage sketch below)
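
As a rough illustration of the new entry points, a minimal loading sketch (the data prefix and dictionary file below are hypothetical, not part of this change):

    from fairseq.data import Dictionary, indexed_dataset

    # Hypothetical prefix of a binarized split and its dictionary file
    path = 'data-bin/iwslt14/train.de-en.de'
    dictionary = Dictionary.load('data-bin/iwslt14/dict.de.txt')

    if indexed_dataset.dataset_exists(path, impl='mmap'):
        ds = indexed_dataset.make_dataset(path, impl='mmap',
                                          fix_lua_indexing=True, dictionary=dictionary)
        print(len(ds), ds[0])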
Pull Request resolved: https://github.com/pytorch/fairseq/pull/589

Differential Revision: D14597128

Pulled By: myleott

fbshipit-source-id: 4e92d99920cbaa52cfe5a0f1f5d9ae5c92d4268e
Authored by Davide Caroselli on 2019-05-07 07:06:16 -07:00; committed by Facebook Github Bot
parent e4edf27a97
commit a1c997bd9a
11 changed files with 289 additions and 123 deletions

View File

@ -9,7 +9,7 @@ from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset
from .backtranslation_dataset import BacktranslationDataset
from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
from .language_pair_dataset import LanguagePairDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .monolingual_dataset import MonolingualDataset
@ -39,6 +39,7 @@ __all__ = [
'IndexedRawTextDataset',
'LanguagePairDataset',
'LMContextWindowDataset',
'MMapIndexedDataset',
'MonolingualDataset',
'NoisingDataset',
'RoundRobinZipDatasets',

View File

@ -4,14 +4,44 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import shutil
import struct
import numpy as np
import torch
def make_builder(out_file, impl):
if impl == 'mmap':
return MMapIndexedDatasetBuilder(out_file)
else:
return IndexedDatasetBuilder(out_file)
def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
if impl == 'raw' and IndexedRawTextDataset.exists(path):
assert dictionary is not None
return IndexedRawTextDataset(path, dictionary)
elif impl == 'lazy' and IndexedDataset.exists(path):
return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
elif impl == 'cached' and IndexedDataset.exists(path):
return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path)
return None
def dataset_exists(path, impl):
if impl == 'raw':
return IndexedRawTextDataset.exists(path)
elif impl == 'mmap':
return MMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
@ -37,6 +67,7 @@ def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
@ -100,8 +131,8 @@ class IndexedDataset(torch.utils.data.Dataset):
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and
os.path.exists(data_file_path(path))
os.path.exists(index_file_path(path)) and
os.path.exists(data_file_path(path))
)
@property
@ -135,7 +166,7 @@ class IndexedCachedDataset(IndexedDataset):
for i in indices:
self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i]
a = self.cache[ptx : ptx + size]
a = self.cache[ptx: ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
@ -149,7 +180,7 @@ class IndexedCachedDataset(IndexedDataset):
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i]
np.copyto(a, self.cache[ptx : ptx + a.size])
np.copyto(a, self.cache[ptx: ptx + a.size])
item = torch.from_numpy(a).long()
if self.fix_lua_indexing:
item -= 1 # subtract 1 for 0-based indexing
@ -262,3 +293,169 @@ class IndexedDatasetBuilder(object):
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
index.close()
def _warmup_mmap_file(path):
with open(path, 'rb') as stream:
while stream.read(100 * 1024 * 1024):
pass
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00'
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, 'wb')
self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack('<Q', 1))
self._file.write(struct.pack('<B', code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes):
pointers = self._get_pointers(sizes)
self._file.write(struct.pack('<Q', len(sizes)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order='C'))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order='C'))
del pointers
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path):
with open(path, 'rb') as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test
version = struct.unpack('<Q', stream.read(8))
assert (1,) == version
dtype_code, = struct.unpack('<B', stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack('<Q', stream.read(8))[0]
offset = stream.tell()
_warmup_mmap_file(path)
self._bin_buffer = memoryview(np.memmap(path, mode='r', order='C'))
self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
offset=offset + self._sizes.nbytes)
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state)
def _do_init(self, path):
self._path = path
self._index = self.Index(index_file_path(self._path))
_warmup_mmap_file(data_file_path(self._path))
self._bin_buffer = memoryview(np.memmap(data_file_path(self._path), mode='r', order='C'))
def __len__(self):
return len(self._index)
def __getitem__(self, i):
ptr, size = self._index[i]
tensor = torch.from_numpy(np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr))
if tensor.dtype == torch.int64:
return tensor
else:
return tensor.long()
@property
def sizes(self):
return self._index.sizes
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and
os.path.exists(data_file_path(path))
)
class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, 'wb')
self._dtype = dtype
self._sizes = []
def add_item(self, tensor):
np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._sizes.append(np_array.size)
def merge_file_(self, another_file):
# Concatenate index
index = MMapIndexedDataset.Index(index_file_path(another_file))
assert index.dtype == self._dtype
for size in index.sizes:
self._sizes.append(size)
# Concatenate data
with open(data_file_path(another_file), 'rb') as f:
shutil.copyfileobj(f, self._data_file)
def finalize(self, index_file):
self._data_file.close()
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
index.write(self._sizes)
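
For reference, a minimal round-trip sketch of the builder/reader pair above (the /tmp prefix and token ids are illustrative only):

    import torch
    from fairseq.data import indexed_dataset

    prefix = '/tmp/mmap_example'  # illustrative output prefix; .bin/.idx paths are derived from it
    builder = indexed_dataset.make_builder(indexed_dataset.data_file_path(prefix), impl='mmap')
    builder.add_item(torch.IntTensor([4, 8, 15]))       # one "sentence" of token ids
    builder.add_item(torch.IntTensor([16, 23, 42, 7]))
    builder.finalize(indexed_dataset.index_file_path(prefix))

    ds = indexed_dataset.make_dataset(prefix, impl='mmap')
    assert len(ds) == 2 and list(ds.sizes) == [3, 4]
    print(ds[1])  # tensor([16, 23, 42,  7])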

View File

@ -198,9 +198,8 @@ def add_preprocess_args(parser):
help="number of source words to retain")
group.add_argument("--alignfile", metavar="ALIGN", default=None,
help="an alignment file (optional)")
group.add_argument("--output-format", metavar="FORMAT", default="binary",
choices=["binary", "raw"],
help="output format (optional)")
parser.add_argument('--dataset-impl', metavar="FORMAT", help='output dataset implementation',
choices=['raw', 'lazy', 'cached', 'mmap'], default='cached')
group.add_argument("--joined-dictionary", action="store_true",
help="Generate joined dictionary")
group.add_argument("--only-source", action="store_true",
@ -226,6 +225,8 @@ def add_dataset_args(parser, train=False, gen=False):
help='maximum number of sentences in a batch')
group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
help='batch size will be a multiplier of this value')
parser.add_argument('--dataset-impl', metavar="FORMAT", help='output dataset implementation',
choices=['raw', 'lazy', 'cached', 'mmap'], default='cached')
if train:
group.add_argument('--train-subset', default='train', metavar='SPLIT',
choices=['train', 'valid', 'test'],
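
A minimal sketch of how the new flag surfaces on the preprocessing parser (the language pair and paths are illustrative, not required values):

    from fairseq import options

    parser = options.get_preprocessing_parser()
    args = parser.parse_args(['--source-lang', 'de', '--target-lang', 'en',
                              '--trainpref', 'train', '--destdir', 'data-bin',
                              '--dataset-impl', 'mmap'])
    print(args.dataset_impl)  # 'mmap'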

View File

@ -17,9 +17,7 @@ from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
from fairseq.data import (
ConcatDataset,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
indexed_dataset,
TokenBlockDataset,
)
@ -118,14 +116,11 @@ class CrossLingualLMTask(FairseqTask):
split_k = split + (str(k) if k > 0 else '')
path = os.path.join(data_path, split_k)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
elif not self.args.raw_text and IndexedDataset.exists(path):
if self.args.lazy_load:
ds = IndexedDataset(path, fix_lua_indexing=True)
else:
ds = IndexedCachedDataset(path, fix_lua_indexing=True)
else:
ds = indexed_dataset.make_dataset(
path, impl=self.args.dataset_impl, fix_lua_indexing=True,
dictionary=self.dictionary,
)
if ds is None:
if k > 0:
break
else:

View File

@ -8,21 +8,19 @@
import itertools
import os
import torch
import numpy as np
import torch
from fairseq import utils
from fairseq.data import (
ConcatDataset,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
MonolingualDataset,
TokenBlockDataset,
TransformEosDataset,
TruncatedDictionary,
indexed_dataset
)
from . import FairseqTask, register_task
@ -101,6 +99,13 @@ class LanguageModelingTask(FairseqTask):
Args:
args (argparse.Namespace): parsed command-line arguments
"""
if getattr(args, 'raw_text', False):
utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
args.dataset_impl = 'raw'
elif getattr(args, 'lazy_load', False):
utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
args.dataset_impl = 'lazy'
dictionary = None
output_dictionary = None
if args.data:
@ -154,15 +159,10 @@ class LanguageModelingTask(FairseqTask):
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
path = os.path.join(data_path, split_k)
ds = indexed_dataset.make_dataset(path, impl=self.args.dataset_impl,
fix_lua_indexing=True, dictionary=self.dictionary)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
elif not self.args.raw_text and IndexedDataset.exists(path):
if self.args.lazy_load:
ds = IndexedDataset(path, fix_lua_indexing=True)
else:
ds = IndexedCachedDataset(path, fix_lua_indexing=True)
else:
if ds is None:
if k > 0:
break
else:

View File

@ -11,17 +11,15 @@ import os
import torch
from fairseq import options
from fairseq import options, utils
from fairseq.data import (
BacktranslationDataset,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
NoisingDataset,
RoundRobinZipDatasets,
TransformEosLangPairDataset,
indexed_dataset,
)
from fairseq.models import FairseqMultiModel
@ -78,7 +76,7 @@ class MultilingualTranslationTask(FairseqTask):
help='target language (only needed for inference)')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--raw-text', action='store_true',
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
help='pad the source on the left (default: True)')
@ -122,6 +120,12 @@ class MultilingualTranslationTask(FairseqTask):
def prepare(cls, args, **kargs):
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
if getattr(args, 'raw_text', False):
utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
args.dataset_impl = 'raw'
elif getattr(args, 'lazy_load', False):
utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
args.dataset_impl = 'lazy'
args.lang_pairs = args.lang_pairs.split(',')
sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}))
@ -196,21 +200,7 @@ class MultilingualTranslationTask(FairseqTask):
def split_exists(split, src, tgt, lang):
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedDataset.exists(filename):
return True
return False
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedDataset.exists(path):
if self.args.lazy_load:
return IndexedDataset(path, fix_lua_indexing=True)
else:
return IndexedCachedDataset(path, fix_lua_indexing=True)
return None
return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)
src_datasets, tgt_datasets = {}, {}
for lang_pair in self.args.lang_pairs:
@ -221,8 +211,10 @@ class MultilingualTranslationTask(FairseqTask):
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
else:
continue
src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
src_datasets[lang_pair] = indexed_dataset.make_dataset(prefix + src, impl=self.args.dataset_impl,
fix_lua_indexing=True, dictionary=self.dicts[src])
tgt_datasets[lang_pair] = indexed_dataset.make_dataset(prefix + tgt, impl=self.args.dataset_impl,
fix_lua_indexing=True, dictionary=self.dicts[tgt])
print('| {} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))
if len(src_datasets) == 0:

View File

@ -8,15 +8,13 @@
import itertools
import os
from fairseq import options
from fairseq import options, utils
from fairseq.data import (
ConcatDataset,
data_utils,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
indexed_dataset
)
from . import FairseqTask, register_task
@ -56,7 +54,7 @@ class TranslationTask(FairseqTask):
help='target language')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--raw-text', action='store_true',
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
help='pad the source on the left')
@ -84,6 +82,12 @@ class TranslationTask(FairseqTask):
"""
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
if getattr(args, 'raw_text', False):
utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
args.dataset_impl = 'raw'
elif getattr(args, 'lazy_load', False):
utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
args.dataset_impl = 'lazy'
paths = args.data.split(':')
assert len(paths) > 0
@ -116,21 +120,7 @@ class TranslationTask(FairseqTask):
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedDataset.exists(filename):
return True
return False
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedDataset.exists(path):
if self.args.lazy_load:
return IndexedDataset(path, fix_lua_indexing=True)
else:
return IndexedCachedDataset(path, fix_lua_indexing=True)
return None
return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)
src_datasets = []
tgt_datasets = []
@ -150,8 +140,10 @@ class TranslationTask(FairseqTask):
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
src_datasets.append(indexed_dataset.make_dataset(prefix + src, impl=self.args.dataset_impl,
fix_lua_indexing=True, dictionary=self.src_dict))
tgt_datasets.append(indexed_dataset.make_dataset(prefix + tgt, impl=self.args.dataset_impl,
fix_lua_indexing=True, dictionary=self.tgt_dict))
print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))

View File

@ -8,17 +8,7 @@ import contextlib
import torch
from fairseq import modules, options, utils
from fairseq.data import (
ConcatDataset,
data_utils,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
)
from fairseq import modules, utils
from . import register_task
from .translation import TranslationTask
@ -40,8 +30,8 @@ class TranslationMoETask(TranslationTask):
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
Args:
src_dict (Dictionary): dictionary for the source language
tgt_dict (Dictionary): dictionary for the target language
src_dict (~fairseq.data.Dictionary): dictionary for the source language
tgt_dict (~fairseq.data.Dictionary): dictionary for the target language
.. note::

View File

@ -129,9 +129,7 @@ def main(args):
)
pool.close()
ds = indexed_dataset.IndexedDatasetBuilder(
dataset_dest_file(args, output_prefix, lang, "bin")
)
ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"), impl=args.dataset_impl)
merge_result(
Binarizer.binarize(
input_file, vocab, lambda t: ds.add_item(t),
@ -161,15 +159,15 @@ def main(args):
)
def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
if args.output_format == "binary":
make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)
elif args.output_format == "raw":
if args.dataset_impl == "raw":
# Copy original text file to destination folder
output_text_file = dest_path(
output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
lang,
)
shutil.copyfile(file_name(input_prefix, lang), output_text_file)
else:
make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)
def make_all(lang, vocab):
if args.trainpref:
@ -233,9 +231,7 @@ def main(args):
def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
ds = indexed_dataset.IndexedDatasetBuilder(
dataset_dest_file(args, output_prefix, lang, "bin")
)
ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"), impl=args.dataset_impl)
def consumer(tensor):
ds.add_item(tensor)
@ -263,15 +259,6 @@ def get_offsets(input_file, num_workers):
return Binarizer.find_offsets(input_file, num_workers)
def merge_files(files, outpath):
ds = indexed_dataset.IndexedDatasetBuilder("{}.bin".format(outpath))
for file in files:
ds.merge_file_(file)
os.remove(indexed_dataset.data_file_path(file))
os.remove(indexed_dataset.index_file_path(file))
ds.finalize("{}.idx".format(outpath))
def cli_main():
parser = options.get_preprocessing_parser()
args = parser.parse_args()

View File

@ -8,29 +8,39 @@
import argparse
from fairseq.data import dictionary
from fairseq.data import IndexedDataset
from fairseq.data import Dictionary
from fairseq.data import indexed_dataset
def get_parser():
parser = argparse.ArgumentParser(
description='writes text from binarized file to stdout')
# fmt: off
parser.add_argument('--dict', metavar='FP', required=True, help='dictionary containing known words')
parser.add_argument('--dataset-impl', help='dataset implementation',
choices=['raw', 'lazy', 'cached', 'mmap'], default='lazy')
parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None)
parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read')
# fmt: on
return parser
def main(args):
dict = dictionary.Dictionary.load(args.dict)
ds = IndexedDataset(args.input, fix_lua_indexing=True)
for tensor_line in ds:
print(dict.string(tensor_line))
def main():
parser = get_parser()
args = parser.parse_args()
dictionary = Dictionary.load(args.dict) if args.dict is not None else None
dataset = indexed_dataset.make_dataset(args.input, impl=args.dataset_impl,
fix_lua_indexing=True, dictionary=dictionary)
for tensor_line in dataset:
if dictionary is None:
line = ' '.join([str(int(x)) for x in tensor_line])
else:
line = dictionary.string(tensor_line)
print(line)
if __name__ == '__main__':
parser = get_parser()
args = parser.parse_args()
main(args)
main()
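
The script can then be run against any of the four implementations, e.g. (hypothetical paths): python scripts/read_binarized.py --dataset-impl mmap --input <binarized prefix> --dict <dictionary file>. When --dict is omitted, raw token ids are printed instead of words.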

View File

@ -38,9 +38,9 @@ class TestTranslation(unittest.TestCase):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_fconv_raw') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir, ['--output-format', 'raw'])
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--raw-text'])
generate_main(data_dir, ['--raw-text'])
preprocess_translation_data(data_dir, ['--dataset-impl', 'raw'])
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--dataset-impl', 'raw'])
generate_main(data_dir, ['--dataset-impl', 'raw'])
def test_fp16(self):
with contextlib.redirect_stdout(StringIO()):
@ -418,7 +418,8 @@ def train_masked_language_model(data_dir, arch):
"--no-progress-bar",
"--distributed-world-size",
"1",
"--raw-text",
"--dataset-impl",
"raw",
],
)
train.main(train_args)