fairseq/tests/test_multi_corpus_dataset.py
Alex Xiao fc2840de58 optimize sampling process of multi_corpus_dataset
Summary:
The sampling process in multi_corpus_dataset is very inefficient. It turns out we can significantly optimize it by sampling in batches rather than one by one (see the sketch below). This allows:

1. fast local development and iteration with corpus sampling, since turnaround time was previously long
2. jobs to start training sooner, giving an earlier signal if, for example, there is a configuration issue
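A minimal sketch of the before/after idea in plain numpy (the function names are illustrative, not the actual MultiCorpusDataset implementation):

import numpy as np

def sample_one_by_one(num_samples, distribution, seed=0):
    # before: one RNG call per index, all inside a Python-level loop
    rng = np.random.RandomState(seed)
    return [rng.choice(len(distribution), p=distribution) for _ in range(num_samples)]

def sample_in_batches(num_samples, distribution, seed=0):
    # after: a single vectorized draw covering every index at once
    rng = np.random.RandomState(seed)
    return rng.choice(len(distribution), size=num_samples, p=distribution).tolist()

Both return, for each index, which corpus that index should be drawn from; the batched version simply avoids the per-index Python and RNG-call overhead.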

Reviewed By: zhengwy888

Differential Revision: D26187821

fbshipit-source-id: b4f7f6b7c187b3785499308226e2af671a6c354f
2021-03-03 19:31:40 -08:00


# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from collections import OrderedDict

import torch

from fairseq.data import LanguagePairDataset, TokenBlockDataset
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
from tests.test_train import mock_dict


class TestMultiCorpusDataset(unittest.TestCase):
    def setUp(self):
        d = mock_dict()

        # First dataset: odd token values 1, 3, ..., 4999 (2500 items).
        tokens_1 = torch.LongTensor([i for i in range(1, 5000, 2)]).view(1, -1)
        tokens_ds1 = TokenBlockDataset(
            tokens_1,
            sizes=[tokens_1.size(-1)],
            block_size=1,
            pad=0,
            eos=1,
            include_targets=False,
        )
        self.dataset_1 = LanguagePairDataset(
            tokens_ds1, tokens_ds1.sizes, d, shuffle=False
        )

        # Second dataset: even token values 0, 2, ..., 4998 (2500 items).
        tokens_2 = torch.LongTensor([i for i in range(0, 5000, 2)]).view(1, -1)
        tokens_ds2 = TokenBlockDataset(
            tokens_2,
            sizes=[tokens_2.size(-1)],
            block_size=1,
            pad=0,
            eos=1,
            include_targets=False,
        )
        self.dataset_2 = LanguagePairDataset(
            tokens_ds2, tokens_ds2.sizes, d, shuffle=False
        )

    def _test_sample_helper(
        self,
        distribution,
    ):
        m = MultiCorpusDataset(
            OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
            distribution=distribution,
            seed=0,
            sort_indices=True,
        )
        m.set_epoch(1)
        indices = m.ordered_indices()

        # Odd token values come from dataset_1 and even values from dataset_2,
        # so the parity of each sample identifies its source corpus.
        count_sample_from_first_dataset = 0
        items = set()
        for i in indices:
            item = m[i]["source"].item()
            if item % 2 == 1:
                count_sample_from_first_dataset += 1
            items.add(item)
        sample_from_first_ds_percentage = (
            1.0 * count_sample_from_first_dataset / len(indices)
        )

        # The empirical sampling ratio should match the requested distribution
        # to within 1%.
        self.assertLess(
            abs(sample_from_first_ds_percentage - distribution[0]),
            0.01,
        )

        # Each corpus contributes at most its full set of distinct items; both
        # datasets have 2500 items here, so using dataset_1 for both bounds is
        # equivalent.
        self.assertEqual(
            len(items),
            int(
                min(len(self.dataset_1), len(indices) * distribution[0])
                + min(len(self.dataset_1), len(indices) * distribution[1])
            ),
        )
        print(distribution)

    def test_multi_corpus_dataset(self):
        for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1]]:
            self._test_sample_helper(distribution=distribution)