mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
4e3895be1c
Summary: Refactoring batch_by_size. You may be required to rebuild Cython components with: `python setup.py build_ext --inplace`. Reviewed By: myleott Differential Revision: D25705733 fbshipit-source-id: a263505276e3d820a9e44b93354ee5ace70d7fc5
137 lines
4.8 KiB
Python
137 lines
4.8 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from fairseq.data.data_utils_fast import batch_by_size_fn
|
|
from fairseq.data.data_utils_fast import batch_by_size_vec
|
|
|
|
|
|
class TestBatchBySize(unittest.TestCase):
    """Shared machinery for checking a batch_by_size implementation.

    Subclasses call ``_run_compare_with_baseline_sweep`` from a ``test_*``
    method, passing the implementation under test. It is compared against
    ``batch_by_size_baseline``, a deliberately simple reference version.
    """

    @classmethod
    def batch_by_size_baseline(
        cls,
        indices,
        num_tokens_vec,
        max_tokens,
        max_sentences,
        bsz_mult,
    ):
        """Simple, reliable and slow implementation of batch by size.

        Greedily grows a batch from ``start`` until adding the next position
        would overflow (or the input is exhausted). A batch holding more than
        ``bsz_mult`` items is then trimmed down to a multiple of ``bsz_mult``.

        Args:
            indices: array of sample indices, batched in the given order.
            num_tokens_vec: token count per *position* of ``indices``.
            max_tokens: cap on ``max(tokens in batch) * batch_size``;
                values <= 0 disable the cap.
            max_sentences: cap on batch size; values <= 0 disable the cap.
            bsz_mult: batches larger than this are rounded down to a multiple
                of it.

        Returns:
            list of slices of ``indices``, one per batch, in order.

        Raises:
            ValueError: if a single sample alone exceeds ``max_tokens``, so no
                non-empty batch can be formed (this used to loop forever).
        """
        batches = []
        total = len(indices)
        start = 0
        while start < total:
            # Running max of the window [start, end); avoids re-scanning the
            # whole window on every step.
            max_val = num_tokens_vec[start]
            for end in range(start + 1, total + 1):
                max_val = max(max_val, num_tokens_vec[end - 1])
                sent_count = end - start
                num_tokens = max_val * sent_count
                # Chained comparisons: a cap of 0 (or less) means "no limit".
                overflow = num_tokens > max_tokens > 0 or sent_count > max_sentences > 0
                terminate = overflow or end == total
                if overflow:
                    # The newest position does not fit; leave it for the next batch.
                    sent_count -= 1
                    if sent_count == 0:
                        raise ValueError(
                            f"sample at position {start} has {max_val} tokens, "
                            f"exceeding max_tokens={max_tokens}"
                        )
                if terminate:
                    if sent_count > bsz_mult:
                        sent_count -= sent_count % bsz_mult
                    batches.append(indices[start : start + sent_count])
                    start += sent_count
                    break
        return batches

    @classmethod
    def _get_error_message(
        cls, max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
    ):
        """Build the assertion message reporting the parameters of a mismatch."""
        return f"""Reference batch_by_size implementation should produce
same output as the baseline method.
Params:
max_sentences={max_sentences},
max_tokens={max_tokens},
bsz_mult={bsz_mult},
num_tokens_vec={num_tokens_vec},
expected_batches={validation},
returned_batches={results}"""

    def _compare_results(
        self,
        indices_len,
        batch_by_size_impl,
        max_sentences,
        max_tokens,
        bsz_mult,
        num_tokens_vec,
    ):
        """Assert ``batch_by_size_impl`` yields exactly the baseline's batches.

        ``batch_by_size_impl`` is called as
        ``impl(indices, num_tokens_vec, max_tokens=..., max_sentences=...,
        bsz_mult=...)`` with ``indices = np.arange(indices_len)``.
        """
        indices = np.arange(indices_len)
        validation = self.batch_by_size_baseline(
            indices,
            num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            bsz_mult=bsz_mult,
        )
        results = batch_by_size_impl(
            indices,
            num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            bsz_mult=bsz_mult,
        )
        error_msg = self._get_error_message(
            max_sentences, max_tokens, bsz_mult, num_tokens_vec, validation, results
        )
        self.assertEqual(len(validation), len(results), error_msg)
        for expected, actual in zip(validation, results):
            self.assertTrue(np.array_equal(expected, actual), error_msg)

    def _run_compare_with_baseline_sweep(self, batch_by_size_impl, seed=0):
        """Compare reference batch_by_size implementation with batch_by_size_baseline
        across a dense grid of hyperparam values.

        Args:
            batch_by_size_impl: implementation under test (see
                ``_compare_results`` for the calling convention).
            seed: seed for the token-count generator, so that any failure is
                reproducible. Pass ``None`` for a fresh, non-deterministic stream.
        """
        MAX_MAX_TOKENS = 10
        NUM_TOKENS_VECS_COUNT = 5
        # Local generator: reproducible, and it leaves global numpy RNG state alone.
        rng = np.random.RandomState(seed)
        for indices_len in [10, 11]:  # try odd and even len of indices
            for max_sentences in range(0, indices_len + 2):
                for max_tokens in range(0, MAX_MAX_TOKENS):
                    for bsz_mult in range(1, max(MAX_MAX_TOKENS, indices_len) + 2):
                        for _ in range(NUM_TOKENS_VECS_COUNT):
                            # Values never exceed max_tokens, so every single
                            # sample fits on its own.
                            num_tokens_vec = rng.randint(
                                0, max_tokens + 1, size=indices_len
                            )
                            self._compare_results(
                                indices_len,
                                batch_by_size_impl,
                                max_sentences,
                                max_tokens,
                                bsz_mult,
                                num_tokens_vec,
                            )
|
|
|
|
|
class TestBatchBySizeVec(TestBatchBySize):
    """Run the baseline sweep against the vector-based ``batch_by_size_vec``."""

    def test_compare_with_baseline(self):
        impl_under_test = batch_by_size_vec
        self._run_compare_with_baseline_sweep(impl_under_test)
|
|
|
|
|
|
class TestBatchBySizeFn(TestBatchBySize):
    """Run the baseline sweep against the callback-based ``batch_by_size_fn``."""

    def test_compare_with_baseline(self):
        # The sweep calls its implementation with keyword arguments
        # max_tokens / max_sentences / bsz_mult, so those names must stay.
        def fn_adapter(indices, num_tokens_vec, max_tokens, max_sentences, bsz_mult):
            # batch_by_size_fn wants a per-index size callback instead of a
            # precomputed vector, so expose the vector through a lookup.
            def lookup(idx):
                return num_tokens_vec[idx]

            return batch_by_size_fn(
                indices, lookup, max_tokens, max_sentences, bsz_mult
            )

        self._run_compare_with_baseline_sweep(fn_adapter)
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|