Improve init speed of TokenBlockDataset and EpochBatchIterator

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/704

Differential Revision: D15221549

Pulled By: myleott

fbshipit-source-id: b0021acdc2d7792ce51421f1432e1f2bd8218f7b
Authored by Myle Ott on 2019-05-07 07:05:49 -07:00; committed by Facebook Github Bot
parent 8d9063fe0d
commit e4edf27a97
4 changed files with 91 additions and 58 deletions


@@ -26,11 +26,11 @@ class CountingIterator(object):
         count (int): number of elements consumed from this iterator
     """

-    def __init__(self, iterable):
+    def __init__(self, iterable, start=0):
         self.iterable = iterable
-        self.count = 0
+        self.count = start
         self.itr = iter(self)
-        self.len = len(iterable)
+        self.len = start + len(iterable)

     def __len__(self):
         return self.len
@@ -50,7 +50,6 @@ class CountingIterator(object):
     def skip(self, num_to_skip):
         """Fast-forward the iterator by skipping *num_to_skip* elements."""
         next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
-        self.len -= num_to_skip
         return self
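
For context on the new `start` argument: instead of fast-forwarding a freshly built iterator element by element, the caller can now wrap only the remaining elements and report the already-consumed prefix through `start`, so `count` and `len()` stay consistent with the full epoch. A minimal standalone sketch of that bookkeeping (a toy class for illustration, not the fairseq CountingIterator itself):

class Counting:
    """Toy counting wrapper that can resume mid-stream (illustrative only)."""

    def __init__(self, iterable, start=0):
        self.itr = iter(iterable)
        self.count = start                  # elements consumed before wrapping
        self.len = start + len(iterable)    # total length includes the skipped prefix

    def __len__(self):
        return self.len

    def __iter__(self):
        return self

    def __next__(self):
        value = next(self.itr)   # raises StopIteration when exhausted
        self.count += 1
        return value


remaining = [3, 4, 5]                 # e.g. the batches left after resuming at offset 2
itr = Counting(remaining, start=2)
assert len(itr) == 5 and itr.count == 2
assert list(itr) == [3, 4, 5] and itr.count == 5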
@@ -149,11 +148,13 @@ class EpochBatchIterator(object):
         itr_pos = state_dict.get('iterations_in_epoch', 0)
         if itr_pos > 0:
             # fast-forward epoch iterator
-            itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
-            if itr_pos < len(itr):
-                self._next_epoch_itr = itr.skip(itr_pos)
+            self._next_epoch_itr = self._get_iterator_for_epoch(
+                self.epoch,
+                shuffle=state_dict.get('shuffle', True),
+                offset=itr_pos,
+            )

-    def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False):
+    def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0):

         def shuffle_batches(batches, seed):
             # set seed based on the seed and epoch number so that we get
@@ -169,25 +170,33 @@ class EpochBatchIterator(object):
                 batches = shuffle_batches(list(batches), self.seed + epoch)
             batches = list(ShardedIterator(
-                batches, self.num_shards, self.shard_id, fill_value=[]))
+                batches, self.num_shards, self.shard_id, fill_value=[]
+            ))
             self.dataset.prefetch([i for s in batches for i in s])
             if shuffle and fix_batches_to_gpus:
                 batches = shuffle_batches(batches, self.seed + epoch + self.shard_id)
         else:
             if shuffle:
                 batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
             else:
                 batches = self.frozen_batches
-            batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
+            batches = list(ShardedIterator(
+                batches, self.num_shards, self.shard_id, fill_value=[]
+            ))

-        return CountingIterator(torch.utils.data.DataLoader(
-            self.dataset,
-            collate_fn=self.collate_fn,
-            batch_sampler=batches,
-            num_workers=self.num_workers,
-        ))
+        if offset > 0 and offset >= len(batches):
+            return None
+
+        return CountingIterator(
+            torch.utils.data.DataLoader(
+                self.dataset,
+                collate_fn=self.collate_fn,
+                batch_sampler=batches[offset:],
+                num_workers=self.num_workers,
+            ),
+            start=offset,
+        )


 class GroupedIterator(object):
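
The net effect of the two hunks above: on checkpoint restore the batch sampler is sliced at the saved offset and handed to a fresh DataLoader, rather than building a loader over the whole epoch and skipping batches one at a time. A rough standalone sketch of that idea with a plain PyTorch DataLoader (the dataset, batch lists, and offset value here are made up for illustration):

import torch
from torch.utils.data import DataLoader

data = torch.arange(12)                                       # toy map-style dataset
batches = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]]  # precomputed batch sampler
offset = 4                                                    # iterations already done this epoch

# Resume by slicing the sampler instead of iterating past the first `offset` batches.
loader = DataLoader(data, batch_sampler=batches[offset:])
for batch in loader:
    print(batch)   # tensor([8, 9]) then tensor([10, 11])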


@@ -67,38 +67,50 @@ class TokenBlockDataset(FairseqDataset):
                 self.slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
             self.slice_indices = np.empty((len(sizes), 2), dtype=int)
-            curr = 0
-            for i, sz in enumerate(sizes):
-                self.slice_indices[i] = (curr, curr + sz)
-                curr += sz
+            if not torch.is_tensor(sizes):
+                sizes = torch.tensor(sizes)
+            cumsum = torch.cumsum(sizes, dim=0)
+            self.slice_indices[0, 1] = sizes[0]
+            self.slice_indices[1:] = cumsum.unfold(0, 2, 1)
         else:
             raise ValueError('Invalid break_mode: ' + break_mode)

-        self.sizes = np.array([e - s for s, e in self.slice_indices])
+        self.slice_indices = np.array(self.slice_indices, dtype=int)
+        self.sizes = self.slice_indices[:, 1] - self.slice_indices[:, 0]

         # build index mapping block indices to the underlying dataset indices
-        self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
-        ds_idx, ds_remaining = -1, 0
-        for i, (s, e) in enumerate(self.slice_indices):
-            to_consume = e - s
-            if ds_remaining == 0:
-                ds_idx += 1
-                ds_remaining = sizes[ds_idx]
-            start_ds_idx = ds_idx
-            start_offset = sizes[ds_idx] - ds_remaining
-            while to_consume > ds_remaining:
-                to_consume -= ds_remaining
-                ds_idx += 1
-                ds_remaining = sizes[ds_idx]
-            ds_remaining -= to_consume
-            self.block_to_dataset_index[i] = (
-                start_ds_idx,  # starting index in dataset
-                start_offset,  # starting offset within starting index
-                ds_idx,  # ending index in dataset
-            )
-        assert ds_remaining == 0
-        assert ds_idx == len(self.dataset) - 1
+        if break_mode == 'eos':
+            # much faster version for eos break mode
+            self.block_to_dataset_index = np.stack(
+                [
+                    np.arange(len(sizes)),  # starting index in dataset
+                    np.zeros(len(sizes), dtype=np.long),  # starting offset within starting index
+                    np.arange(len(sizes))  # ending index in dataset
+                ],
+                1,
+            )
+        else:
+            self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
+            ds_idx, ds_remaining = -1, 0
+            for i, (s, e) in enumerate(self.slice_indices):
+                to_consume = e - s
+                if ds_remaining == 0:
+                    ds_idx += 1
+                    ds_remaining = sizes[ds_idx]
+                start_ds_idx = ds_idx
+                start_offset = sizes[ds_idx] - ds_remaining
+                while to_consume > ds_remaining:
+                    to_consume -= ds_remaining
+                    ds_idx += 1
+                    ds_remaining = sizes[ds_idx]
+                ds_remaining -= to_consume
+                self.block_to_dataset_index[i] = (
+                    start_ds_idx,  # starting index in dataset
+                    start_offset,  # starting offset within starting index
+                    ds_idx,  # ending index in dataset
+                )
+            assert ds_remaining == 0
+            assert ds_idx == len(self.dataset) - 1

     def __getitem__(self, index):
         start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
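
The eos branch above computes all (start, end) slice boundaries with a single cumulative sum instead of a Python loop over every sentence, and the block-to-dataset map becomes the identity because each block is exactly one sentence. A small self-contained check of those two identities with toy sizes (illustrative only, not fairseq code; the first row is set explicitly here to keep the toy array fully initialized):

import numpy as np
import torch

sizes = torch.tensor([5, 1, 4])                  # toy sentence lengths, one block per sentence
cumsum = torch.cumsum(sizes, dim=0)              # tensor([ 5,  6, 10])

slice_indices = np.empty((len(sizes), 2), dtype=int)
slice_indices[0] = (0, sizes[0].item())          # first block starts at token 0
slice_indices[1:] = cumsum.unfold(0, 2, 1)       # adjacent cumsums give the remaining (start, end) pairs
print(slice_indices.tolist())                    # [[0, 5], [5, 6], [6, 10]]

# In eos mode, block i maps to sentence i with zero offset: arange / zeros / arange.
block_to_dataset_index = np.stack(
    [np.arange(len(sizes)), np.zeros(len(sizes), dtype=int), np.arange(len(sizes))], 1
)
print(block_to_dataset_index.tolist())           # [[0, 0, 0], [1, 0, 1], [2, 0, 2]]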


@@ -23,9 +23,9 @@ class TestTokenBlockDataset(unittest.TestCase):
     def test_eos_break_mode(self):
         data = [
-            torch.LongTensor([5, 4, 3, 2, 1]),
-            torch.LongTensor([1]),  # this should be filtered
-            torch.LongTensor([8, 7, 6, 1]),
+            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
+            torch.tensor([1], dtype=torch.long),
+            torch.tensor([8, 7, 6, 1], dtype=torch.long),
         ]
         ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
         self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
@@ -33,9 +33,9 @@ class TestTokenBlockDataset(unittest.TestCase):
         self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])

         data = [
-            torch.LongTensor([5, 4, 3, 2, 1]),
-            torch.LongTensor([8, 7, 6, 1]),
-            torch.LongTensor([1]),  # this should be filtered
+            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
+            torch.tensor([8, 7, 6, 1], dtype=torch.long),
+            torch.tensor([1], dtype=torch.long),
         ]
         ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
         self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
@@ -44,9 +44,9 @@ class TestTokenBlockDataset(unittest.TestCase):
     def test_block_break_mode(self):
         data = [
-            torch.LongTensor([5, 4, 3, 2, 1]),
-            torch.LongTensor([8, 7, 6, 1]),
-            torch.LongTensor([9, 1]),
+            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
+            torch.tensor([8, 7, 6, 1], dtype=torch.long),
+            torch.tensor([9, 1], dtype=torch.long),
         ]
         ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none')
         self.assertEqual(ds[0].tolist(), [5, 4, 3])
@@ -56,19 +56,19 @@ class TestTokenBlockDataset(unittest.TestCase):
     def test_complete_break_mode(self):
         data = [
-            torch.LongTensor([5, 4, 3, 2, 1]),
-            torch.LongTensor([8, 7, 6, 1]),
-            torch.LongTensor([9, 1]),
+            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
+            torch.tensor([8, 7, 6, 1], dtype=torch.long),
+            torch.tensor([9, 1], dtype=torch.long),
         ]
         ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete')
         self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
         self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])

         data = [
-            torch.LongTensor([4, 3, 2, 1]),
-            torch.LongTensor([5, 1]),
-            torch.LongTensor([1]),
-            torch.LongTensor([6, 1]),
+            torch.tensor([4, 3, 2, 1], dtype=torch.long),
+            torch.tensor([5, 1], dtype=torch.long),
+            torch.tensor([1], dtype=torch.long),
+            torch.tensor([6, 1], dtype=torch.long),
         ]
         ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete')
         self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])


@@ -85,6 +85,18 @@ class TestLoadCheckpoint(unittest.TestCase):
             self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
             self.assertEqual(epoch_itr.iterations_in_epoch, 51)

+            for _ in range(150 - 52):
+                next(itr)
+            self.assertEqual(epoch_itr.iterations_in_epoch, 149)
+            self.assertTrue(itr.has_next())
+            next(itr)
+            self.assertFalse(itr.has_next())
+
+            itr = epoch_itr.next_epoch_itr(shuffle=False)
+            self.assertTrue(itr.has_next())
+            self.assertEqual(epoch_itr.epoch, 3)
+            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
+
     def test_load_full_checkpoint(self):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)