Cythonize token block dataset (#834)

Summary:
Cythonized the token block dataset code; it's `> 100x` faster. Building the token blocks for the entire `bookwiki+CC+stories+openweb` corpus now takes just ~`39.9` seconds.

TODO:
1) I think I can make it another 2x faster.
2) Cleanup.

EDIT History:
~~First pass at parallelizing `token_block_dataset`. The code feels somewhat complicated and cluttered.
This is 2-3x faster, though, in my tests on the `bookwiki` dataset with both `complete` and `complete_doc` modes.
myleott: Can you take a look for correctness? I am still not 100% sure that I am not missing corner cases.~~
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/834

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

Test workflow: f133816198

Reviewed By: myleott

Differential Revision: D16970257

Pulled By: myleott

fbshipit-source-id: ec45a308193c9e9f3e7075336c15df4723228d6f
Commit 4fc39538ae (parent 6e2bd794e0)
Naman Goyal, 2019-08-23 07:31:18 -07:00, committed by Facebook Github Bot
6 changed files with 286 additions and 155 deletions

File: fairseq/data/data_utils.py

@ -12,6 +12,10 @@ import itertools
 import os
 import numpy as np
+import sys
+import types
+
+from fairseq.data.data_utils_fast import batch_by_size_fast


 def infer_language_pair(path):
@ -196,45 +200,13 @@ def batch_by_size(
         required_batch_size_multiple (int, optional): require batch size to
             be a multiple of N (default: 1).
     """
-    max_tokens = max_tokens if max_tokens is not None else float('Inf')
-    max_sentences = max_sentences if max_sentences is not None else float('Inf')
+    max_tokens = max_tokens if max_tokens is not None else sys.maxsize
+    max_sentences = max_sentences if max_sentences is not None else sys.maxsize
     bsz_mult = required_batch_size_multiple

-    batch = []
-
-    def is_batch_full(num_tokens):
-        if len(batch) == 0:
-            return False
-        if len(batch) == max_sentences:
-            return True
-        if num_tokens > max_tokens:
-            return True
-        return False
-
-    sample_len = 0
-    sample_lens = []
-    for idx in indices:
-        sample_lens.append(num_tokens_fn(idx))
-        sample_len = max(sample_len, sample_lens[-1])
-        assert sample_len <= max_tokens, (
-            "sentence at index {} of size {} exceeds max_tokens "
-            "limit of {}!".format(idx, sample_len, max_tokens)
-        )
-        num_tokens = (len(batch) + 1) * sample_len
-        if is_batch_full(num_tokens):
-            mod_len = max(
-                bsz_mult * (len(batch) // bsz_mult),
-                len(batch) % bsz_mult,
-            )
-            yield batch[:mod_len]
-            batch = batch[mod_len:]
-            sample_lens = sample_lens[mod_len:]
-            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
-        batch.append(idx)
-
-    if len(batch) > 0:
-        yield batch
+    if isinstance(indices, types.GeneratorType):
+        indices = np.fromiter(indices, dtype=np.int64, count=-1)
+
+    return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult)


 def process_bpe_symbol(sentence: str, bpe_symbol: str):
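
For context, a minimal usage sketch of the rewritten entry point (hypothetical lengths, not from this diff; assumes the Cython extension has been built):

```python
import numpy as np
from fairseq.data import data_utils

lengths = [3, 7, 5, 2, 9, 4]  # hypothetical per-example token counts

# indices may be an int64 array or a generator; generators are materialized
# via np.fromiter before being handed to the Cython fast path.
batches = data_utils.batch_by_size(
    np.arange(len(lengths), dtype=np.int64),
    num_tokens_fn=lambda i: lengths[i],
    max_tokens=12,
    max_sentences=4,
)
print(batches)  # a list of lists of dataset indices
```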

File: fairseq/data/data_utils_fast.pyx (new file)

@ -0,0 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

cimport cython
cimport numpy as np

DTYPE = np.int64
ctypedef np.int64_t DTYPE_t


cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences):
    if len(batch) == 0:
        return 0
    if len(batch) == max_sentences:
        return 1
    if num_tokens > max_tokens:
        return 1
    return 0


@cython.cdivision(True)
cpdef list batch_by_size_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    long max_tokens,
    long max_sentences,
    int bsz_mult,
):
    cdef long sample_len = 0
    cdef list sample_lens = []
    cdef list batch = []
    cdef list batches = []
    cdef long mod_len
    cdef long i
    cdef long idx
    cdef long num_tokens
    cdef DTYPE_t[:] indices_view = indices

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0

        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)

    return batches
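
The `mod_len` computation above trims each emitted batch to a multiple of `bsz_mult` whenever the batch holds at least `bsz_mult` sentences, carrying the remainder into the next batch. A pure-Python illustration with hypothetical numbers (not part of the diff):

```python
def mod_len(batch_len, bsz_mult):
    # Largest multiple of bsz_mult that fits, unless the batch is smaller
    # than bsz_mult, in which case the whole batch is emitted.
    return max(bsz_mult * (batch_len // bsz_mult), batch_len % bsz_mult)

assert mod_len(10, 8) == 8  # emit 8 sentences, carry 2 forward
assert mod_len(5, 8) == 5   # fewer than bsz_mult sentences: emit all 5
```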

File: fairseq/data/token_block_dataset.py

@ -3,11 +3,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import math
 import numpy as np
 import torch

+from fairseq.data.token_block_utils_fast import (
+    _get_slice_indices_fast,
+    _get_block_to_dataset_index_fast,
+)
 from fairseq.data import FairseqDataset, plasma_utils
@ -33,7 +36,6 @@ class TokenBlockDataset(FairseqDataset):
             'complete_doc' break mode). Typically 1 if the sentences have eos
             and 0 otherwise.
     """

     def __init__(
         self,
         dataset,
@ -50,70 +52,22 @@ class TokenBlockDataset(FairseqDataset):
         self.pad = pad
         self.eos = eos
         self.include_targets = include_targets
-
-        slice_indices = []

         assert len(dataset) == len(sizes)
         assert len(dataset) > 0
-        sizes = np.array(sizes, dtype=int)
-
-        if break_mode is None or break_mode == "none":
-            total_size = sum(sizes)
-            length = math.ceil(total_size / block_size)
-
-            def block_at(i):
-                start = i * block_size
-                end = min(start + block_size, total_size)
-                return (start, end)
-
-            slice_indices = [block_at(i) for i in range(length)]
-        elif break_mode == "complete":
-            tok_idx = 0
-            sz_idx = 0
-            curr_size = 0
-            while sz_idx < len(sizes):
-                if curr_size + sizes[sz_idx] <= block_size or curr_size == 0:
-                    curr_size += sizes[sz_idx]
-                    sz_idx += 1
-                else:
-                    slice_indices.append((tok_idx, tok_idx + curr_size))
-                    tok_idx += curr_size
-                    curr_size = 0
-            if curr_size > 0:
-                slice_indices.append((tok_idx, tok_idx + curr_size))
-        elif break_mode == "complete_doc":
-            tok_idx = 0
-            sz_idx = 0
-            curr_size = 0
-            while sz_idx < len(sizes):
-                if (
-                    (curr_size + sizes[sz_idx] <= block_size or curr_size == 0)
-                    # an empty sentence indicates end-of-document:
-                    and sizes[sz_idx] != document_sep_len
-                ):
-                    curr_size += sizes[sz_idx]
-                    sz_idx += 1
-                else:
-                    if curr_size > 1:
-                        slice_indices.append((tok_idx, tok_idx + curr_size))
-                    tok_idx += curr_size
-                    curr_size = 0
-                    if sizes[sz_idx] == document_sep_len:
-                        tok_idx += sizes[sz_idx]
-                        sz_idx += 1
-            if curr_size > 1:
-                slice_indices.append((tok_idx, tok_idx + curr_size))
-        elif break_mode == "eos":
-            slice_indices = np.empty((len(sizes), 2), dtype=int)
-            if not torch.is_tensor(sizes):
-                sizes = torch.tensor(sizes)
-            cumsum = torch.cumsum(sizes, dim=0)
-            slice_indices[0] = [0, sizes[0]]
-            if len(cumsum) > 1:
-                slice_indices[1:] = cumsum.unfold(0, 2, 1)
+        if isinstance(sizes, list):
+            sizes = np.array(sizes, dtype=np.int64)
         else:
-            raise ValueError("Invalid break_mode: " + break_mode)
-        slice_indices = np.array(slice_indices, dtype=int)
+            sizes = sizes.astype(np.int64)
+
+        break_mode = break_mode if break_mode is not None else 'none'
+
+        # For the "eos" break mode, block_size is not a required parameter.
+        if break_mode == "eos" and block_size is None:
+            block_size = 0
+
+        slice_indices = _get_slice_indices_fast(sizes, break_mode, block_size, document_sep_len)
         self._sizes = slice_indices[:, 1] - slice_indices[:, 0]

         # build index mapping block indices to the underlying dataset indices
@ -130,23 +84,10 @@ class TokenBlockDataset(FairseqDataset):
                 1,
             )
         else:
-            ds = DatasetSearcher(sizes)
-            block_to_dataset_index = np.empty((len(slice_indices), 3), dtype=int)
-            for i, (s, e) in enumerate(slice_indices):
-                ds.seek(s)
-                start_ds_idx = ds.current_index
-                start_offset = ds.current_offset
-                if e <= s:
-                    end_ds_idx = start_ds_idx
-                else:
-                    ds.seek(e - 1)
-                    end_ds_idx = ds.current_index
-                block_to_dataset_index[i] = (
-                    start_ds_idx,  # starting index in dataset
-                    start_offset,  # starting offset within starting index
-                    end_ds_idx,  # ending index in dataset
-                )
+            block_to_dataset_index = _get_block_to_dataset_index_fast(
+                sizes,
+                slice_indices,
+            )

         self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
         self._sizes = plasma_utils.PlasmaArray(self._sizes)
         self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index)
@ -215,42 +156,3 @@ class TokenBlockDataset(FairseqDataset):
                 for ds_idx in range(start_ds_idx, end_ds_idx + 1)
             }
         )
-
-
-class DatasetSearcher(object):
-    """Helper for mapping "flat" indices to indices and offsets in an
-    underlying dataset."""
-
-    def __init__(self, sizes):
-        self.sizes = sizes
-        self.reset()
-
-    def reset(self):
-        self.current_index = 0  # index in underlying dataset
-        self.current_offset = 0  # offset within current index in underlying dataset
-        self.current_i = 0  # "flat" index
-
-    def seek(self, i):
-        assert i >= 0
-
-        def step():
-            if i < self.current_i:
-                self.reset()
-            if i > self.current_i:
-                to_consume = i - self.current_i
-                remaining = self.sizes[self.current_index] - self.current_offset
-                if remaining > to_consume:
-                    self.current_offset += to_consume
-                    self.current_i += to_consume
-                else:
-                    assert remaining > 0
-                    self.current_i += remaining
-                    self.current_index += 1
-                    self.current_offset = 0
-                return True
-            return False
-
-        not_done = True
-        while not_done:
-            not_done = step()
-        assert self.current_i == i

File: fairseq/data/token_block_utils_fast.pyx (new file)

@ -0,0 +1,184 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from itertools import chain
from libc.math cimport ceil

cimport cython
cimport numpy as np

DTYPE = np.int64
ctypedef np.int64_t DTYPE_t


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size):
    cdef DTYPE_t total_size = sizes.sum()
    cdef DTYPE_t length = <DTYPE_t> ceil(total_size / <double> block_size)
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE)
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices
    cdef DTYPE_t i
    cdef DTYPE_t start
    cdef DTYPE_t end

    for i in range(length):
        start = i * block_size
        end = min(start + block_size, total_size)
        slice_indices_view[i][0] = start
        slice_indices_view[i][1] = end
    return slice_indices


cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list):
    """
    Faster function to convert a list of lists of DTYPE_t values to a 2-D array.
    Only fast when there is a huge number of rows and a small number of columns.
    """
    cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1)
    return flat.reshape((len(list_of_list), -1))
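
The helper above avoids a per-row `np.array` call by flattening through `itertools.chain`; a pure-NumPy equivalent of the trick (hypothetical rows, not part of the diff):

```python
import numpy as np
from itertools import chain

rows = [(0, 5), (5, 9), (9, 14)]  # hypothetical (start, end) pairs
flat = np.fromiter(chain.from_iterable(rows), np.int64, -1)
arr = flat.reshape((len(rows), -1))  # shape (3, 2), built in one pass
```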
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len):
    cdef DTYPE_t tok_idx = 0
    cdef DTYPE_t sz_idx = 0
    cdef DTYPE_t curr_size = 0
    cdef DTYPE_t i = 0
    cdef DTYPE_t length
    cdef DTYPE_t total_size
    cdef DTYPE_t[:] sizes_view = sizes
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices
    cdef list slice_indices_list = []

    if break_mode is None or break_mode == 'none':
        slice_indices = _get_slice_indices_none_mode(sizes, block_size)
    elif break_mode == 'complete':
        while sz_idx < len(sizes_view):
            if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0:
                curr_size += sizes_view[sz_idx]
                sz_idx += 1
            else:
                slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
        if curr_size > 0:
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
        slice_indices = _fast_convert_to_np_array(slice_indices_list)
    elif break_mode == 'complete_doc':
        while sz_idx < len(sizes_view):
            if (
                (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0)
                # an empty sentence indicates end-of-document:
                and sizes_view[sz_idx] != document_sep_len
            ):
                curr_size += sizes_view[sz_idx]
                sz_idx += 1
            else:
                # Only keep non-empty documents.
                if curr_size > 1:
                    slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
                if sizes_view[sz_idx] == document_sep_len:
                    tok_idx += sizes_view[sz_idx]
                    sz_idx += 1
        if curr_size > 1:
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
        slice_indices = _fast_convert_to_np_array(slice_indices_list)
    elif break_mode == 'eos':
        slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE)
        cumsum = sizes.cumsum(axis=0)
        slice_indices[1:, 0] = cumsum[:-1]
        slice_indices[:, 1] = cumsum
    else:
        raise ValueError('Invalid break_mode: ' + break_mode)
    return slice_indices
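
A small worked example of what the three break modes produce (hypothetical sizes; assumes the built extension is importable):

```python
import numpy as np
from fairseq.data.token_block_utils_fast import _get_slice_indices_fast

sizes = np.array([3, 2, 4], dtype=np.int64)  # hypothetical sentence lengths

_get_slice_indices_fast(sizes, 'none', 5, 1)      # [[0 5] [5 9]]: fixed 5-token blocks
_get_slice_indices_fast(sizes, 'complete', 5, 1)  # [[0 5] [5 9]]: whole sentences only
_get_slice_indices_fast(sizes, 'eos', 0, 1)       # [[0 3] [3 5] [5 9]]: one block per sentence
```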
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices):
    cdef DTYPE_t start_ds_idx
    cdef DTYPE_t start_offset
    cdef DTYPE_t end_ds_idx
    cdef DTYPE_t i
    cdef DTYPE_t s
    cdef DTYPE_t e
    cdef DatasetSearcher ds = DatasetSearcher(sizes)
    cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE)
    cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices
    cdef Py_ssize_t x_max = slice_indices.shape[0]

    for i in range(x_max):
        s = slice_indices_view[i][0]
        e = slice_indices_view[i][1]
        ds.seek(s)
        start_ds_idx = ds.current_index
        start_offset = ds.current_offset
        if e <= s:
            end_ds_idx = start_ds_idx
        else:
            ds.seek(e - 1)
            end_ds_idx = ds.current_index
        block_to_dataset_index_view[i][0] = start_ds_idx  # starting index in dataset
        block_to_dataset_index_view[i][1] = start_offset  # starting offset within starting index
        block_to_dataset_index_view[i][2] = end_ds_idx    # ending index in dataset
    return block_to_dataset_index
cdef class DatasetSearcher(object):
    """Helper for mapping "flat" indices to indices and offsets in an
    underlying dataset."""
    cdef DTYPE_t current_i
    cdef DTYPE_t current_offset
    cdef DTYPE_t current_index
    cdef DTYPE_t[:] sizes

    def __init__(self, DTYPE_t[:] sizes):
        self.sizes = sizes
        self.reset()

    cdef reset(self):
        self.current_offset = 0  # offset within current index in underlying dataset
        self.current_i = 0  # "flat" index
        self.current_index = 0  # index in underlying dataset

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.nonecheck(False)
    cdef int step(self, DTYPE_t i):
        cdef DTYPE_t to_consume
        cdef DTYPE_t remaining

        if i < self.current_i:
            self.reset()
        if i > self.current_i:
            to_consume = i - self.current_i
            remaining = self.sizes[self.current_index] - self.current_offset
            if remaining > to_consume:
                self.current_offset += to_consume
                self.current_i += to_consume
            else:
                assert remaining > 0
                self.current_i += remaining
                self.current_index += 1
                self.current_offset = 0
            return 1
        return 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.nonecheck(False)
    cdef seek(self, DTYPE_t i):
        cdef int not_done = 1
        while not_done == 1:
            not_done = self.step(i)
        assert self.current_i == i
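
`DatasetSearcher` walks forward incrementally, but what `seek` computes is simply the mapping from a flat token position to a (dataset item, offset) pair. A stateless pure-Python sketch of that mapping (hypothetical helper, for intuition only):

```python
def locate(sizes, i):
    """Map flat token position i to (dataset_index, offset_within_item)."""
    idx = 0
    while i >= sizes[idx]:
        i -= sizes[idx]
        idx += 1
    return idx, i

assert locate([3, 2, 4], 4) == (1, 1)  # flat position 4 = token 1 of item 1
assert locate([3, 2, 4], 0) == (0, 0)
```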

File: fairseq/tasks/fairseq_task.py

@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import numpy as np
 import torch

 from fairseq import tokenizer
@ -134,6 +135,7 @@ class FairseqTask(object):
         indices = data_utils.filter_by_size(
             indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs),
         )
+        indices = np.fromiter(indices, dtype=np.int64, count=-1)

         # create mini-batches with given size constraints
         batch_sampler = data_utils.batch_by_size(
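
`filter_by_size` may yield indices lazily, so they are materialized into an `int64` array before batching; a minimal illustration of the conversion (hypothetical generator):

```python
import numpy as np

gen = (i for i in range(5) if i != 2)  # hypothetical filtered indices
indices = np.fromiter(gen, dtype=np.int64, count=-1)
# array([0, 1, 3, 4])
```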

File: setup.py

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
 from setuptools import setup, find_packages, Extension
+from Cython.Build import cythonize
 import sys
@ -27,6 +28,8 @@ bleu = Extension(
     extra_compile_args=extra_compile_args,
 )

+token_block_utils = cythonize("fairseq/data/token_block_utils_fast.pyx")
+data_utils_fast = cythonize("fairseq/data/data_utils_fast.pyx", language="c++")

 setup(
     name='fairseq',
@ -52,7 +55,7 @@ setup(
         'tqdm',
     ],
     packages=find_packages(exclude=['scripts', 'tests']),
-    ext_modules=[bleu],
+    ext_modules=token_block_utils + data_utils_fast + [bleu],
     test_suite='tests',
     entry_points={
         'console_scripts': [
@ -65,4 +68,5 @@ setup(
             'fairseq-validate = fairseq_cli.validate:cli_main',
         ],
     },
+    zip_safe=False,
 )
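
With the `.pyx` modules listed in `ext_modules`, a normal `pip install --editable .` compiles them as part of the build; when iterating locally, the standard setuptools command `python setup.py build_ext --inplace` rebuilds just the extensions (standard setuptools workflow, not specific to this diff).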