mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
a48f235636
Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1357 Reviewed By: alexeib Differential Revision: D24377772 fbshipit-source-id: 51581af041d42d62166b33a35a1a4228b1a76f0c
104 lines
3.3 KiB
Python
104 lines
3.3 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import collections
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from fairseq.data import ListDataset, ResamplingDataset
|
|
|
|
|
|
class TestResamplingDataset(unittest.TestCase):
|
|
def setUp(self):
|
|
self.strings = ["ab", "c", "def", "ghij"]
|
|
self.weights = [4.0, 2.0, 7.0, 1.5]
|
|
self.size_ratio = 2
|
|
self.dataset = ListDataset(
|
|
self.strings, np.array([len(s) for s in self.strings])
|
|
)
|
|
|
|
def _test_common(self, resampling_dataset, iters):
|
|
assert len(self.dataset) == len(self.strings) == len(self.weights)
|
|
assert len(resampling_dataset) == self.size_ratio * len(self.strings)
|
|
|
|
results = {"ordered_by_size": True, "max_distribution_diff": 0.0}
|
|
|
|
totalfreqs = 0
|
|
freqs = collections.defaultdict(int)
|
|
|
|
for epoch_num in range(iters):
|
|
resampling_dataset.set_epoch(epoch_num)
|
|
|
|
indices = resampling_dataset.ordered_indices()
|
|
assert len(indices) == len(resampling_dataset)
|
|
|
|
prev_size = -1
|
|
|
|
for i in indices:
|
|
cur_size = resampling_dataset.size(i)
|
|
# Make sure indices map to same sequences within an epoch
|
|
assert resampling_dataset[i] == resampling_dataset[i]
|
|
|
|
# Make sure length of sequence is correct
|
|
assert cur_size == len(resampling_dataset[i])
|
|
|
|
freqs[resampling_dataset[i]] += 1
|
|
totalfreqs += 1
|
|
|
|
if prev_size > cur_size:
|
|
results["ordered_by_size"] = False
|
|
|
|
prev_size = cur_size
|
|
|
|
assert set(freqs.keys()) == set(self.strings)
|
|
for s, weight in zip(self.strings, self.weights):
|
|
freq = freqs[s] / totalfreqs
|
|
expected_freq = weight / sum(self.weights)
|
|
results["max_distribution_diff"] = max(
|
|
results["max_distribution_diff"], abs(expected_freq - freq)
|
|
)
|
|
|
|
return results
|
|
|
|
def test_resampling_dataset_batch_by_size_false(self):
|
|
resampling_dataset = ResamplingDataset(
|
|
self.dataset,
|
|
self.weights,
|
|
size_ratio=self.size_ratio,
|
|
batch_by_size=False,
|
|
seed=0,
|
|
)
|
|
|
|
results = self._test_common(resampling_dataset, iters=1000)
|
|
|
|
# For batch_by_size = False, the batches should be returned in
|
|
# arbitrary order of size.
|
|
assert not results["ordered_by_size"]
|
|
|
|
# Allow tolerance in distribution error of 2%.
|
|
assert results["max_distribution_diff"] < 0.02
|
|
|
|
def test_resampling_dataset_batch_by_size_true(self):
|
|
resampling_dataset = ResamplingDataset(
|
|
self.dataset,
|
|
self.weights,
|
|
size_ratio=self.size_ratio,
|
|
batch_by_size=True,
|
|
seed=0,
|
|
)
|
|
|
|
results = self._test_common(resampling_dataset, iters=1000)
|
|
|
|
# For batch_by_size = True, the batches should be returned in
|
|
# increasing order of size.
|
|
assert results["ordered_by_size"]
|
|
|
|
# Allow tolerance in distribution error of 2%.
|
|
assert results["max_distribution_diff"] < 0.02
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed as a script, load the TestCase classes from this module
    # and run them with unittest's command-line runner.
    unittest.main(module="__main__")
|