batch_by_size refactoring: 100x speedup and reduced memory footprint

Summary: Refactoring batch_by_size. You may need to rebuild the Cython components with `python setup.py build_ext --inplace`.

Reviewed By: myleott

Differential Revision: D25705733

fbshipit-source-id: a263505276e3d820a9e44b93354ee5ace70d7fc5
Authored by Ruslan Mavlyutov on 2020-12-28 21:03:59 -08:00, committed by Facebook GitHub Bot
parent 982ec32976
commit 4e3895be1c
6 changed files with 287 additions and 55 deletions

fairseq/data/data_utils.py
View File

@@ -80,8 +80,8 @@ def load_indexed_dataset(
combine 'data-bin/train', 'data-bin/train1', ... and return a
single ConcatDataset instance.
"""
from fairseq.data.concat_dataset import ConcatDataset
import fairseq.data.indexed_dataset as indexed_dataset
from fairseq.data.concat_dataset import ConcatDataset
datasets = []
for k in itertools.count():
@@ -276,6 +276,7 @@ def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_siz
def batch_by_size(
indices,
num_tokens_fn,
num_tokens_vec=None,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
@@ -289,6 +290,8 @@ def batch_by_size(
indices (List[int]): ordered list of dataset indices
num_tokens_fn (callable): function that returns the number of tokens at
a given index
num_tokens_vec (List[int], optional): precomputed vector of the number
of tokens for each index in indices (to enable faster batch generation)
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
@@ -301,7 +304,8 @@
"""
try:
from fairseq.data.data_utils_fast import (
batch_by_size_fast,
batch_by_size_fn,
batch_by_size_vec,
batch_fixed_shapes_fast,
)
except ImportError:
@@ -317,14 +321,27 @@ def batch_by_size(
if not isinstance(indices, np.ndarray):
indices = np.fromiter(indices, dtype=np.int64, count=-1)
if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
if fixed_shapes is None:
return batch_by_size_fast(
indices,
num_tokens_fn,
max_tokens,
max_sentences,
bsz_mult,
)
if num_tokens_vec is None:
return batch_by_size_fn(
indices,
num_tokens_fn,
max_tokens,
max_sentences,
bsz_mult,
)
else:
return batch_by_size_vec(
indices,
num_tokens_vec,
max_tokens,
max_sentences,
bsz_mult,
)
else:
fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
sort_order = np.lexsort(
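For illustration, a minimal sketch (not part of this commit; the toy lengths below are assumed) of the refactored entry point: when a precomputed num_tokens_vec is supplied, batch_by_size dispatches to the vectorized Cython kernel, otherwise it falls back to one num_tokens_fn call per index.

import numpy as np
from fairseq.data import data_utils

lengths = np.array([3, 7, 4, 4, 9, 2, 5], dtype=np.int64)  # toy sentence lengths
indices = np.arange(len(lengths))

def num_tokens_fn(idx):
    # Per-index lookup; only exercised on the fallback path.
    return int(lengths[idx])

# Fast path: token counts are precomputed once, so batching stays in Cython.
fast = data_utils.batch_by_size(
    indices, num_tokens_fn, num_tokens_vec=lengths, max_tokens=12
)
# Fallback path: token counts are gathered via a Python call per index.
slow = data_utils.batch_by_size(indices, num_tokens_fn, max_tokens=12)
assert len(fast) == len(slow)
assert all(np.array_equal(a, b) for a, b in zip(fast, slow))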

fairseq/data/data_utils_fast.pyx
View File

@@ -10,63 +10,115 @@ cimport cython
cimport numpy as np
from libc.stdint cimport int32_t, int64_t
from libcpp cimport bool as bool_t
ctypedef int64_t DTYPE_t
cdef _is_batch_full(int64_t num_sentences, int64_t num_tokens, int64_t max_tokens, int64_t max_sentences):
if num_sentences == 0:
return 0
if max_sentences > 0 and num_sentences == max_sentences:
return 1
if max_tokens > 0 and num_tokens > max_tokens:
return 1
return 0
@cython.cdivision(True)
cpdef list batch_by_size_fast(
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_vec(
np.ndarray[int64_t, ndim=1] indices,
np.ndarray[int64_t, ndim=1] num_tokens_vec,
int64_t max_tokens,
int64_t max_sentences,
int32_t bsz_mult,
):
assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
f"Sentence lengths should not exceed max_tokens={max_tokens}"
)
cdef int32_t indices_len = indices.shape[0]
cdef np.ndarray[int32_t, ndim=1] batches_ends = \
np.zeros(indices_len, dtype=np.int32)
cdef int32_t[:] batches_ends_view = batches_ends
cdef int64_t[:] num_tokens_view = num_tokens_vec
cdef int32_t pos = 0
cdef int32_t new_batch_end = 0
cdef int64_t new_batch_max_tokens = 0
cdef int32_t new_batch_sentences = 0
cdef int64_t new_batch_num_tokens = 0
cdef bool_t overflow = False
cdef bool_t size_matches_with_bsz_mult = False
cdef int32_t batches_count = 0
cdef int32_t batch_start = 0
cdef int64_t tail_max_tokens = 0
cdef int64_t batch_max_tokens = 0
for pos in range(indices_len):
# At every pos we keep stats about the last complete batch [batch_start:batch_end)
# and the tail [batch_end:pos].
# 1) Whenever (batch + tail) forms a valid batch
# (according to max_tokens, max_sentences and bsz_mult), we append the tail to the batch.
# 2) When (batch + tail) violates the max_tokens or max_sentences constraints,
# we finalize the running batch, and the tail becomes the new batch.
# 3) There is a corner case where the tail itself also violates the constraints.
# In that situation [batch_end:pos-1] (the tail without the current pos)
# is added to the finalized batches, while [pos:pos] becomes the new tail.
#
# Important: for the sake of performance, avoid function calls within this loop.
tail_max_tokens = tail_max_tokens \
if tail_max_tokens > num_tokens_view[pos] \
else num_tokens_view[pos]
new_batch_end = pos + 1
new_batch_max_tokens = batch_max_tokens \
if batch_max_tokens > tail_max_tokens \
else tail_max_tokens
new_batch_sentences = new_batch_end - batch_start
new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens
overflow = (new_batch_sentences > max_sentences > 0 or
new_batch_num_tokens > max_tokens > 0)
size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
new_batch_sentences % bsz_mult == 0)
if overflow:
tail_num_tokens = tail_max_tokens * \
(new_batch_end - batches_ends_view[batches_count])
tail_overflow = tail_num_tokens > max_tokens > 0
# In case of a tail overflow finalize two batches
if tail_overflow:
batches_count += 1
batches_ends_view[batches_count] = pos
tail_max_tokens = num_tokens_view[pos]
batch_start = batches_ends_view[batches_count]
batches_count += 1
new_batch_max_tokens = tail_max_tokens
if overflow or size_matches_with_bsz_mult:
batches_ends_view[batches_count] = new_batch_end
batch_max_tokens = new_batch_max_tokens
tail_max_tokens = 0
if batches_ends_view[batches_count] != indices_len:
batches_count += 1
# Memory and time-efficient split
return np.split(indices, batches_ends[:batches_count])
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_fn(
np.ndarray[DTYPE_t, ndim=1] indices,
num_tokens_fn,
int64_t max_tokens,
int64_t max_sentences,
int32_t bsz_mult,
):
cdef int64_t sample_len = 0
cdef list sample_lens = []
cdef list batch = []
cdef list batches = []
cdef int64_t mod_len
cdef int64_t i
cdef int64_t idx
cdef int64_t num_tokens
cdef int32_t indices_len = indices.shape[0]
cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len,
dtype=np.int64)
cdef DTYPE_t[:] indices_view = indices
for i in range(len(indices_view)):
idx = indices_view[i]
num_tokens = num_tokens_fn(idx)
sample_lens.append(num_tokens)
sample_len = max(sample_len, num_tokens)
assert max_tokens <= 0 or sample_len <= max_tokens, (
"sentence at index {} of size {} exceeds max_tokens "
"limit of {}!".format(idx, sample_len, max_tokens)
)
num_tokens = (len(batch) + 1) * sample_len
if _is_batch_full(len(batch), num_tokens, max_tokens, max_sentences):
mod_len = max(
bsz_mult * (len(batch) // bsz_mult),
len(batch) % bsz_mult,
)
batches.append(batch[:mod_len])
batch = batch[mod_len:]
sample_lens = sample_lens[mod_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
if len(batch) > 0:
batches.append(batch)
return batches
cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
cdef int64_t pos
for pos in range(indices_len):
num_tokens_vec[pos] = num_tokens_fn(indices_view[pos])
return batch_by_size_vec(indices, num_tokens_vec, max_tokens,
max_sentences, bsz_mult,)
cdef _find_valid_shape(
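To make the batch/tail bookkeeping above concrete, here is a small worked example (illustrative only; the toy token counts are assumed). A candidate batch costs new_batch_sentences * new_batch_max_tokens, so a single long sentence inflates the cost of its short neighbors; on overflow the running batch is finalized and the tail starts a new one.

import numpy as np
from fairseq.data.data_utils_fast import batch_by_size_vec

indices = np.arange(6, dtype=np.int64)
num_tokens_vec = np.array([2, 2, 2, 8, 1, 1], dtype=np.int64)

# max_tokens=8, max_sentences disabled (-1), bsz_mult=1:
# [0, 1, 2] costs 3 * 2 = 6 tokens; adding index 3 would cost 4 * 8 = 32,
# so [0, 1, 2] is finalized and the tail restarts at index 3. Index 3 alone
# already saturates max_tokens, so it is finalized next, leaving [4, 5].
print(batch_by_size_vec(indices, num_tokens_vec, 8, -1, 1))
# [array([0, 1, 2]), array([3]), array([4, 5])]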

fairseq/data/fairseq_dataset.py
View File

@@ -3,10 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import torch.utils.data
from fairseq.data import data_utils
logger = logging.getLogger(__name__)
class EpochListening:
"""Mixin for receiving updates whenever the epoch increments."""
@@ -54,6 +57,11 @@ class FairseqDataset(torch.utils.data.Dataset, EpochListening):
enforce ``--max-tokens`` during batching."""
raise NotImplementedError
def num_tokens_vec(self, indices):
"""Return the number of tokens for a set of positions defined by indices.
This value is used to enforce ``--max-tokens`` during batching."""
raise NotImplementedError
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
@@ -129,9 +137,15 @@ class FairseqDataset(torch.utils.data.Dataset, EpochListening):
]
)
try:
num_tokens_vec = self.num_tokens_vec(indices).astype('int64')
except NotImplementedError:
num_tokens_vec = None
return data_utils.batch_by_size(
indices,
num_tokens_fn=self.num_tokens,
num_tokens_vec=num_tokens_vec,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
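A hedged sketch (the class below is hypothetical, not from this commit) of how a dataset opts into the fast path: override num_tokens_vec with a vectorized lookup, and batch_sampler forwards the precomputed vector to data_utils.batch_by_size; datasets that do not override it keep working through the NotImplementedError fallback above.

import numpy as np
from fairseq.data import FairseqDataset

class ToySizedDataset(FairseqDataset):
    """Sketch only: __getitem__, __len__ and collater are omitted."""

    def __init__(self, sizes):
        self.sizes = np.asarray(sizes, dtype=np.int64)

    def num_tokens(self, index):
        return self.sizes[index]

    def num_tokens_vec(self, indices):
        # One vectorized fancy-indexing lookup instead of len(indices)
        # Python-level num_tokens calls.
        return self.sizes[indices]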

fairseq/data/language_pair_dataset.py
View File

@@ -408,6 +408,14 @@ class LanguagePairDataset(FairseqDataset):
self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
)
def num_tokens_vec(self, indices):
"""Return the number of tokens for a set of positions defined by indices.
This value is used to enforce ``--max-tokens`` during batching."""
sizes = self.src_sizes[indices]
if self.tgt_sizes is not None:
sizes = np.maximum(sizes, self.tgt_sizes[indices])
return sizes
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
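For intuition (toy sizes assumed), the per-pair token count is the elementwise maximum of source and target lengths:

import numpy as np

src_sizes = np.array([5, 2, 7])
tgt_sizes = np.array([3, 4, 6])
indices = np.array([0, 1, 2])
# Whichever side of each pair is longer determines its token count:
print(np.maximum(src_sizes[indices], tgt_sizes[indices]))  # [5 4 7]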

fairseq/data/multilingual/sampled_multi_dataset.py
View File

@@ -238,6 +238,11 @@ class SampledMultiDataset(FairseqDataset):
def num_tokens(self, index):
return self.sizes[index].max()
def num_tokens_vec(self, indices):
sizes_vec = self.sizes[np.array(indices)]
# max across all dimensions but first one
return np.amax(sizes_vec, axis=tuple(range(1, len(sizes_vec.shape))))
def size(self, index):
return self.sizes[index]
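The axis=tuple(range(1, ...)) reduction keeps the first (sample) axis and takes the maximum over every other axis, whatever the rank of self.sizes. A toy illustration (shapes assumed):

import numpy as np

sizes_vec = np.array([[5, 3], [2, 4], [7, 6]])  # (n_samples, n_fields)
print(np.amax(sizes_vec, axis=tuple(range(1, len(sizes_vec.shape)))))  # [5 4 7]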

tests/test_data_utils.py (new file, 136 lines)
View File

@@ -0,0 +1,136 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import numpy as np
from fairseq.data.data_utils_fast import batch_by_size_fn
from fairseq.data.data_utils_fast import batch_by_size_vec
class TestBatchBySize(unittest.TestCase):
@classmethod
def batch_by_size_baseline(
cls,
indices,
num_tokens_vec,
max_tokens,
max_sentences,
bsz_mult,
):
"""Simple, reliable and slow implementation of batch by size """
batches = []
start = 0
while start < len(indices):
# Grow a candidate batch one sentence at a time until it overflows or
# the indices run out.
for end in range(start + 1, len(indices) + 1):
max_val = max(num_tokens_vec[pos] for pos in range(start, end))
sent_count = end - start
num_tokens = max_val * sent_count
overflow = num_tokens > max_tokens > 0 or sent_count > max_sentences > 0
terminate = overflow or end == len(indices)
if overflow:
# The last sentence pushed the batch over a limit; drop it.
sent_count -= 1
if terminate:
# Trim the emitted batch down to a multiple of bsz_mult.
if sent_count > bsz_mult:
sent_count = sent_count - sent_count % bsz_mult
batches.append(indices[start : start + sent_count])
start = start + sent_count
break
return batches
@classmethod
def _get_error_message(
cls, max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
):
return f"""Reference batch_by_size implementation should produce
same output as the baseline method.
Params:
max_sentences={max_sentences},
max_tokens={max_tokens},
bsz_mult={bsz_mult},
num_tokens_vec={num_tokens_vec},
expected_batches={validation},
returned_batches={results}"""
def _compare_results(
self,
indices_len,
batch_by_size_impl,
max_sentences,
max_tokens,
bsz_mult,
num_tokens_vec,
):
indices = np.array(list(range(indices_len)))
validation = self.batch_by_size_baseline(
indices,
num_tokens_vec,
max_tokens=max_tokens,
max_sentences=max_sentences,
bsz_mult=bsz_mult,
)
results = batch_by_size_impl(
indices,
num_tokens_vec,
max_tokens=max_tokens,
max_sentences=max_sentences,
bsz_mult=bsz_mult,
)
error_msg = self._get_error_message(
max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
)
self.assertEqual(len(validation), len(results), error_msg)
for first, second in zip(validation, results):
self.assertTrue(np.array_equal(first, second), error_msg)
def _run_compare_with_baseline_sweep(self, batch_by_size_impl):
"""Compare reference batch_by_size implementation with batch_by_size_baseline
across a dense grid of hyperparam values"""
MAX_MAX_TOKENS = 10
NUM_TOKENS_VECS_COUNT = 5
for indices_len in [10, 11]: # try odd and even len of indices
for max_sentences in range(0, indices_len + 2):
for max_tokens in range(0, MAX_MAX_TOKENS):
for bsz_mult in range(1, max(MAX_MAX_TOKENS, indices_len) + 2):
for _ in range(NUM_TOKENS_VECS_COUNT):
num_tokens_vec = np.random.randint(
0, max_tokens + 1, size=indices_len
)
self._compare_results(
indices_len,
batch_by_size_impl,
max_sentences,
max_tokens,
bsz_mult,
num_tokens_vec,
)
class TestBatchBySizeVec(TestBatchBySize):
def test_compare_with_baseline(self):
self._run_compare_with_baseline_sweep(batch_by_size_vec)
class TestBatchBySizeFn(TestBatchBySize):
def test_compare_with_baseline(self):
def batch_by_size_fn_wrapper(
indices,
num_tokens_vec,
max_tokens,
max_sentences,
bsz_mult,
):
def num_tokens_fn(idx):
return num_tokens_vec[idx]
return batch_by_size_fn(
indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult
)
self._run_compare_with_baseline_sweep(batch_by_size_fn_wrapper)
if __name__ == "__main__":
unittest.main()
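Since the module ends with unittest.main(), the sweep can be run directly with `python tests/test_data_utils.py`, or with `python -m pytest tests/test_data_utils.py` if pytest is available.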