Add dataset class for weighted sampling with replacement. (#861)

Summary:
As discussed with Naman earlier today: weighted sampling with
replacement can be done on a per-epoch basis using the `set_epoch()`
functionality, which generates the sampled indices as a function of the
random seed and the epoch number.

Additionally, `FairseqTask` needs to set the starting epoch for the
dataset at the very beginning of iterator construction.

Not yet implemented is the per-epoch iterator construction, which
is necessary to actually regenerate the batches for each epoch.
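
For context, a minimal usage sketch (the data and weights are illustrative;
`ListDataset` is used the same way as in the new unit test below):

import numpy as np
from fairseq.data import ListDataset, ResamplingDataset

strings = ["ab", "c", "def", "ghij"]
dataset = ListDataset(strings, np.array([len(s) for s in strings]))

resampled = ResamplingDataset(
    dataset,
    weights=[4.0, 2.0, 7.0, 1.5],  # unnormalized sampling weights
    size_ratio=2.0,                # draw 2x the base dataset size per epoch
    seed=0,
)

for epoch in range(3):
    # Redraws the weighted sample deterministically from (seed, epoch).
    resampled.set_epoch(epoch)
    indices = resampled.ordered_indices()
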
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/861

Differential Revision: D17460687

Pulled By: jma127

fbshipit-source-id: 1c2a54f04ac96b3561c100a6fd66a9fccbe3c658
Jerry Ma 2019-09-19 10:34:23 -07:00 committed by Facebook Github Bot
parent 0eaaf35516
commit a8a85c2676
4 changed files with 236 additions and 0 deletions

fairseq/data/__init__.py

@@ -41,6 +41,7 @@ from .token_block_dataset import TokenBlockDataset
 from .transform_eos_dataset import TransformEosDataset
 from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
 from .truncate_dataset import TruncateDataset
+from .resampling_dataset import ResamplingDataset
 from .iterators import (
     CountingIterator,

fairseq/data/resampling_dataset.py

@@ -0,0 +1,128 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np

from . import BaseWrapperDataset, plasma_utils


class ResamplingDataset(BaseWrapperDataset):
    """Randomly samples from a given dataset at each epoch.

    Sampling is done with or without replacement, depending on the "replace"
    parameter.

    Optionally, the epoch size can be rescaled. This is potentially desirable
    to increase per-epoch coverage of the base dataset (since sampling with
    replacement means that many items in the dataset will be left out). In the
    case of sampling without replacement, size_ratio should be strictly less
    than 1.

    Args:
        dataset (~torch.utils.data.Dataset): dataset on which to sample.
        weights (List[float]): list of probability weights
            (default: None, which corresponds to uniform sampling).
        replace (bool): sampling mode; True for "with replacement", or False
            for "without replacement" (default: True).
        size_ratio (float): the ratio to subsample to; must be positive
            (default: 1.0).
        batch_by_size (bool): whether or not to batch by sequence length
            (default: True).
        seed (int): RNG seed to use (default: 0).
        epoch (int): starting epoch number (default: 0).
    """

    def __init__(
        self,
        dataset,
        weights=None,
        replace=True,
        size_ratio=1.0,
        batch_by_size=True,
        seed=0,
        epoch=0,
    ):
        super().__init__(dataset)

        if weights is None:
            self.weights = None
        else:
            assert len(weights) == len(dataset)
            weights_arr = np.array(weights, dtype=np.float64)
            weights_arr /= weights_arr.sum()
            self.weights = plasma_utils.PlasmaArray(weights_arr)

        self.replace = replace

        assert size_ratio > 0.0
        if not self.replace:
            assert size_ratio < 1.0
        self.size_ratio = float(size_ratio)
        self.actual_size = np.ceil(len(dataset) * self.size_ratio).astype(int)

        self.batch_by_size = batch_by_size
        self.seed = seed

        self._cur_epoch = None
        self._cur_indices = None

        self.set_epoch(epoch)

    def __getitem__(self, index):
        return self.dataset[self._cur_indices.array[index]]

    def __len__(self):
        return self.actual_size

    @property
    def sizes(self):
        return self.dataset.sizes[self._cur_indices.array]

    def num_tokens(self, index):
        return self.dataset.num_tokens(self._cur_indices.array[index])

    def size(self, index):
        return self.dataset.size(self._cur_indices.array[index])

    def ordered_indices(self):
        if self.batch_by_size:
            order = [
                np.arange(len(self)),
                self.sizes,
            ]  # No need to handle `self.shuffle == True`
            return np.lexsort(order)
        else:
            return np.arange(len(self))

    def prefetch(self, indices):
        self.dataset.prefetch(self._cur_indices.array[indices])

    def set_epoch(self, epoch):
        super().set_epoch(epoch)

        if epoch == self._cur_epoch:
            return

        self._cur_epoch = epoch

        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.
        rng = np.random.RandomState(
            [
                42,  # magic number
                self.seed % (2 ** 32),  # global seed
                self._cur_epoch,  # epoch index
            ]
        )
        self._cur_indices = plasma_utils.PlasmaArray(
            rng.choice(
                len(self.dataset),
                self.actual_size,
                replace=self.replace,
                p=(None if self.weights is None else self.weights.array),
            )
        )
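
As a sanity check on the epoch behavior, the draw above can be reproduced
outside the class. This is a sketch using the same seeding scheme;
`epoch_sample` is an illustrative helper, not part of the commit:

import numpy as np

def epoch_sample(n, size, seed, epoch, p=None):
    # Mirrors set_epoch(): the RNG is keyed on (magic number, seed, epoch).
    rng = np.random.RandomState([42, seed % (2 ** 32), epoch])
    return rng.choice(n, size, replace=True, p=p)

a = epoch_sample(100, 10, seed=0, epoch=1)
b = epoch_sample(100, 10, seed=0, epoch=1)
c = epoch_sample(100, 10, seed=0, epoch=2)

assert np.array_equal(a, b)      # same (seed, epoch) -> same indices
assert not np.array_equal(a, c)  # new epoch -> a fresh sample (w.h.p.)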

fairseq/tasks/fairseq_task.py

@@ -126,6 +126,9 @@ class FairseqTask(object):
         """
         assert isinstance(dataset, FairseqDataset)
 
+        # initialize the dataset with the correct starting epoch
+        dataset.set_epoch(epoch)
+
         # get indices ordered by example size
         with data_utils.numpy_seed(seed):
             indices = dataset.ordered_indices()
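
From the caller's side, the per-epoch iterator construction that the summary
flags as not yet implemented would look roughly like the sketch below;
`iterate_epochs`, the split name, and the `max_tokens` value are illustrative
assumptions, not part of this commit:

from fairseq.tasks import FairseqTask

def iterate_epochs(task: FairseqTask, num_epochs: int, seed: int = 1):
    for epoch in range(num_epochs):
        # get_batch_iterator() now calls dataset.set_epoch(epoch) first,
        # so a ResamplingDataset redraws its sample before batching.
        yield epoch, task.get_batch_iterator(
            dataset=task.dataset("train"),
            max_tokens=4000,  # illustrative batching limit
            seed=seed,
            epoch=epoch,
        )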

tests/test_resampling_dataset.py

@@ -0,0 +1,104 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import collections
import unittest

import numpy as np

from fairseq.data import ListDataset, ResamplingDataset


class TestResamplingDataset(unittest.TestCase):
    def setUp(self):
        self.strings = ["ab", "c", "def", "ghij"]
        self.weights = [4.0, 2.0, 7.0, 1.5]
        self.size_ratio = 2
        self.dataset = ListDataset(
            self.strings, np.array([len(s) for s in self.strings])
        )

    def _test_common(self, resampling_dataset, iters):
        assert len(self.dataset) == len(self.strings) == len(self.weights)
        assert len(resampling_dataset) == self.size_ratio * len(self.strings)

        results = {"ordered_by_size": True, "max_distribution_diff": 0.0}

        totalfreqs = 0
        freqs = collections.defaultdict(int)

        for epoch_num in range(iters):
            resampling_dataset.set_epoch(epoch_num)

            indices = resampling_dataset.ordered_indices()
            assert len(indices) == len(resampling_dataset)

            prev_size = -1

            for i in indices:
                cur_size = resampling_dataset.size(i)

                # Make sure indices map to same sequences within an epoch
                assert resampling_dataset[i] == resampling_dataset[i]

                # Make sure length of sequence is correct
                assert cur_size == len(resampling_dataset[i])

                freqs[resampling_dataset[i]] += 1
                totalfreqs += 1

                if prev_size > cur_size:
                    results["ordered_by_size"] = False

                prev_size = cur_size

        assert set(freqs.keys()) == set(self.strings)
        for s, weight in zip(self.strings, self.weights):
            freq = freqs[s] / totalfreqs
            expected_freq = weight / sum(self.weights)
            results["max_distribution_diff"] = max(
                results["max_distribution_diff"], abs(expected_freq - freq)
            )

        return results

    def test_resampling_dataset_batch_by_size_false(self):
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=False,
            seed=0,
        )

        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = False, the batches should be returned in
        # arbitrary order of size.
        assert not results["ordered_by_size"]

        # Allow tolerance in distribution error of 2%.
        assert results["max_distribution_diff"] < 0.02

    def test_resampling_dataset_batch_by_size_true(self):
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=True,
            seed=0,
        )

        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = True, the batches should be returned in
        # increasing order of size.
        assert results["ordered_by_size"]

        # Allow tolerance in distribution error of 2%.
        assert results["max_distribution_diff"] < 0.02


if __name__ == "__main__":
    unittest.main()