optimize sampling process of multi_corpus_dataset

Summary:
The sampling process in multi_corpus_dataset is very inefficient. It turns out we can significantly optimize it by sampling in batches rather than one by one (a rough sketch of the idea follows this list). This allows:

1. faster local development and iteration with corpus sampling, since the turnaround time was previously long
2. less time before our jobs can start training, enabling earlier signal if, for example, there is a configuration issue
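
To make the batching idea concrete, here is a minimal standalone sketch of the two strategies. It is illustration only, not the fairseq implementation (the actual change is in the diff below); the helper names, sizes and distribution values are made up.

import numpy as np

def sample_one_by_one(sizes, distribution, total, rng):
    # Old-style approach: one RNG call (plus bookkeeping) per sampled index.
    offsets = np.cumsum([0] + list(sizes[:-1]))
    out = np.empty(total, dtype=np.int64)
    for j in range(total):
        d = rng.choice(len(sizes), p=distribution)    # pick a dataset per sample
        out[j] = offsets[d] + rng.integers(sizes[d])  # pick an item within it
    return out

def sample_in_batches(sizes, distribution, total, rng):
    # Batched approach: decide up front how many samples each dataset
    # contributes, then draw them with a few vectorized numpy calls.
    offsets = np.cumsum([0] + list(sizes[:-1]))
    chunks, selected = [], 0
    for i, size in enumerate(sizes):
        n = int(distribution[i] * total) if i < len(sizes) - 1 else total - selected
        selected += n
        full, rem = divmod(n, size)  # full passes over the dataset + random remainder
        idx = np.concatenate(
            (np.repeat(np.arange(size), full), rng.permutation(size)[:rem])
        ) + offsets[i]
        chunks.append(idx)
    out = np.concatenate(chunks)
    rng.shuffle(out)  # interleave the corpora again
    return out.astype(np.int64)

rng = np.random.default_rng(0)
print(sample_one_by_one([5, 3], [0.75, 0.25], 8, rng))
print(sample_in_batches([5, 3], [0.75, 0.25], 8, rng))

The one-by-one version pays a Python-level np.random.choice call (with a probability vector, which is relatively expensive) for every sampled index, while the batched version issues a constant number of vectorized numpy calls per corpus, which is where the speedup comes from.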

Reviewed By: zhengwy888

Differential Revision: D26187821

fbshipit-source-id: b4f7f6b7c187b3785499308226e2af671a6c354f
Alex Xiao 2021-03-03 19:29:55 -08:00 committed by Facebook GitHub Bot
parent 1fed7a8426
commit fc2840de58
2 changed files with 63 additions and 34 deletions

fairseq/data/multi_corpus_dataset.py

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 import logging
+import time
 from collections import OrderedDict
 from typing import Dict, List
@@ -12,7 +13,6 @@ from fairseq.data import data_utils
 from . import FairseqDataset
 logger = logging.getLogger(__name__)
@@ -49,6 +49,7 @@ class MultiCorpusDataset(FairseqDataset):
         super().__init__()
         assert isinstance(datasets, OrderedDict)
         assert len(datasets) == len(distribution)
+        assert sum(distribution) == 1
         self.datasets = datasets
         self.distribution = distribution
         self.seed = seed
@@ -69,43 +70,61 @@ class MultiCorpusDataset(FairseqDataset):
             self.total_num_instances += len(dataset)
 
     def ordered_indices(self):
+        start = time.time()
         with data_utils.numpy_seed(self.seed, self.epoch):
-            # Used to store the order of indices of each dataset to use
-            indices = [
-                np.random.permutation(len(dataset))
-                for dataset in self.datasets.values()
-            ]
-            # Keep track of which samples we've used for each dataset
-            counters = [0 for _ in self.datasets]
+            sampled_indices = []
+            num_selected_instances = 0
 
-            sampled_indices = [
-                self._sample(indices, counters) for _ in range(self.total_num_instances)
-            ]
+            # For each dataset i, sample self.distribution[i] * self.total_num_instances
+            for i, key in enumerate(self.datasets):
+                if i < len(self.datasets) - 1:
+                    num_instances = int(self.distribution[i] * self.total_num_instances)
+                    high = self.dataset_offsets[i + 1]
+                else:
+                    num_instances = self.total_num_instances - num_selected_instances
+                    high = self.total_num_instances
+
+                logger.info(f"sampling {num_instances} from {key} dataset")
+                num_selected_instances += num_instances
+
+                # First, add k copies of the dataset where k = num_instances // len(dataset).
+                # This ensures an equal distribution of the data points as much as possible.
+                # For the remaining entries randomly sample them
+                dataset_size = len(self.datasets[key])
+                num_copies = num_instances // dataset_size
+                dataset_indices = (
+                    np.random.permutation(high - self.dataset_offsets[i])
+                    + self.dataset_offsets[i]
+                )[: num_instances - num_copies * dataset_size]
+                if num_copies > 0:
+                    sampled_indices += list(
+                        np.concatenate(
+                            (
+                                np.repeat(
+                                    np.arange(self.dataset_offsets[i], high), num_copies
+                                ),
+                                dataset_indices,
+                            )
+                        )
+                    )
+                else:
+                    sampled_indices += list(dataset_indices)
+
+            assert (
+                len(sampled_indices) == self.total_num_instances
+            ), f"{len(sampled_indices)} vs {self.total_num_instances}"
+
             np.random.shuffle(sampled_indices)
             if self.sort_indices:
                 sampled_indices.sort(key=lambda i: self.num_tokens(i))
-            return np.array(sampled_indices, dtype=np.int64)
 
-    def _sample(self, indices, counters):
-        # First pick dataset
-        dataset_idx = np.random.choice(len(self.distribution), p=self.distribution)
-
-        # Then get dataset internal index
-        idx = indices[dataset_idx][counters[dataset_idx]]
-
-        # Convert to multi-datasets index
-        idx += self.dataset_offsets[dataset_idx]
-
-        counters[dataset_idx] += 1
-
-        # Reset if we reach end
-        if counters[dataset_idx] == len(self.dataset_list[dataset_idx]):
-            counters[dataset_idx] = 0
-            indices[dataset_idx] = np.random.permutation(
-                len(self.dataset_list[dataset_idx])
-            )
-
-        return idx
+            logger.info(
+                "multi_corpus_dataset ordered_indices took {}s".format(
+                    time.time() - start
+                )
+            )
+            return np.array(sampled_indices, dtype=np.int64)
 
     def _map_index(self, index: int):
         """

tests/test_multi_corpus_dataset.py

@@ -27,7 +27,7 @@ class TestMultiCorpusDataset(unittest.TestCase):
         self.dataset_1 = LanguagePairDataset(
             tokens_ds1, tokens_ds1.sizes, d, shuffle=False
         )
-        tokens_2 = torch.LongTensor([i for i in range(2, 5000, 2)]).view(1, -1)
+        tokens_2 = torch.LongTensor([i for i in range(0, 5000, 2)]).view(1, -1)
         tokens_ds2 = TokenBlockDataset(
             tokens_2,
             sizes=[tokens_2.size(-1)],
@@ -53,9 +53,13 @@ class TestMultiCorpusDataset(unittest.TestCase):
         m.set_epoch(1)
         indices = m.ordered_indices()
         count_sample_from_first_dataset = 0
+        items = set()
         for i in indices:
-            if m[i]["source"].item() % 2 == 1:
+            item = m[i]["source"].item()
+            if item % 2 == 1:
                 count_sample_from_first_dataset += 1
+            items.add(item)
         sample_from_first_ds_percentage = (
             1.0 * count_sample_from_first_dataset / len(indices)
         )
@@ -63,6 +67,12 @@ class TestMultiCorpusDataset(unittest.TestCase):
             abs(sample_from_first_ds_percentage - distribution[0]),
             0.01,
         )
+        self.assertEqual(
+            len(items),
+            int(min(len(self.dataset_1), len(indices) * distribution[0])
+                + min(len(self.dataset_1), len(indices) * distribution[1]))
+        )
+        print(distribution)
 
     def test_multi_corpus_dataset(self):
         for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1]]:
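
For intuition about the new assertEqual above: a corpus can contribute at most min(corpus size, number of samples drawn from it) distinct items, and the batched sampler is expected to hit that bound exactly. A toy calculation with made-up sizes (not the sizes this test actually uses):

corpus_sizes = [2000, 3000]
distribution = [0.1, 0.9]
total = sum(corpus_sizes)  # the sampler draws one index per instance overall

expected_unique = sum(
    min(size, int(total * p)) for size, p in zip(corpus_sizes, distribution)
)
print(expected_unique)  # min(2000, 500) + min(3000, 4500) = 500 + 3000 = 3500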