Update iterators to support counting, rename CountingIterator.count -> n and add tests (#1166)

Summary:
A few changes here:
- update GroupedIterator and ShardedIterator to support counting. This will be useful on TPUs, since the TPU dataloading threads may advance through the data faster than we can process it.
- add tests for the above
- in CountingIterator, rename `count` -> `n`. This is needed because `count` is overloaded for iterables (e.g., `list` defines a different `count` method, which is actually a search function).
- in CountingIterator, rename `override_len` -> `total` to be more consistent with other iterators (e.g., tqdm). This functionality was unused previously (it's only needed for TPUs), so the rename is easy.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1166

Reviewed By: ngoyal2707

Differential Revision: D21373525

Pulled By: myleott

fbshipit-source-id: 102f3d50ed1a5163a7d1216ca5a179564a05dfe4
This commit is contained in:
Myle Ott 2020-05-14 13:55:12 -07:00 committed by Facebook GitHub Bot
parent 9a718e2985
commit 803c0a6d11
3 changed files with 123 additions and 65 deletions

View File

@ -5,6 +5,7 @@
import itertools
import math
import operator
import os
import time
import numpy as np
@ -27,31 +28,37 @@ class CountingIterator(object):
Args:
iterable (iterable): iterable to wrap
start (int): starting iteration count
override_len (int): override the iterator length
returned by ``__len__``
start (int): starting iteration count. Note that this doesn't
actually advance the iterator.
total (int): override the iterator length returned by
``__len__``. This can be used to truncate *iterator*.
Attributes:
count (int): number of elements consumed from this iterator
n (int): number of elements consumed from this iterator
"""
def __init__(self, iterable, start=0, override_len=None):
def __init__(self, iterable, start=None, total=None):
self.iterable = iterable
self.count = start
self.itr = iter(self)
if override_len is None:
self.len = start + len(iterable)
if start is None:
self.n = getattr(iterable, 'n', 0)
else:
self.len = override_len
self.n = start
if total is None:
self.total = self.n + len(iterable)
else:
self.total = total
def __len__(self):
return self.len
return self.total
def __iter__(self):
for x in self.iterable:
if self.count >= self.len:
if self.n >= self.total:
return
self.count += 1
self.n += 1
yield x
def __next__(self):
@ -59,7 +66,7 @@ class CountingIterator(object):
def has_next(self):
"""Whether the iterator still has elements left (i.e., is not yet exhausted)."""
return self.count < len(self)
return self.n < len(self)
def skip(self, num_to_skip):
"""Fast-forward the iterator by skipping *num_to_skip* elements."""
@ -70,7 +77,7 @@ class CountingIterator(object):
"""
Truncates the iterator to n elements at most.
"""
self.len = min(self.len, n)
self.total = min(self.total, n)
class EpochBatchIterating(object):
@ -148,7 +155,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
@property
def iterations_in_epoch(self) -> int:
if self._current_epoch_iterator is not None:
return self._current_epoch_iterator.count
return self._current_epoch_iterator.n
return 0
def state_dict(self):
@ -213,7 +220,11 @@ class EpochBatchIterator(EpochBatchIterating):
self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)
def __len__(self):
return len(self.frozen_batches)
return int(math.ceil(len(self.frozen_batches) / float(self.num_shards)))
@property
def n(self):
return self.iterations_in_epoch
@property
def next_epoch_idx(self):
@ -255,9 +266,9 @@ class EpochBatchIterator(EpochBatchIterating):
def iterations_in_epoch(self):
"""The number of consumed batches in the current epoch."""
if self._cur_epoch_itr is not None:
return self._cur_epoch_itr.count
return self._cur_epoch_itr.n
elif self._next_epoch_itr is not None:
return self._next_epoch_itr.count
return self._next_epoch_itr.n
return 0
def state_dict(self):
@ -337,38 +348,39 @@ class EpochBatchIterator(EpochBatchIterating):
return itr
class GroupedIterator(object):
class GroupedIterator(CountingIterator):
"""Wrapper around an iterable that returns groups (chunks) of items.
Args:
iterable (iterable): iterable to wrap
chunk_size (int): size of each chunk
Attributes:
n (int): number of elements consumed from this iterator
"""
def __init__(self, iterable, chunk_size):
self._len = int(math.ceil(len(iterable) / float(chunk_size)))
self.offset = int(math.ceil(getattr(iterable, 'count', 0) / float(chunk_size)))
self.itr = iterable
itr = _chunk_iterator(iterable, chunk_size)
super().__init__(
itr,
start=int(math.ceil(getattr(iterable, 'n', 0) / float(chunk_size))),
total=int(math.ceil(len(iterable) / float(chunk_size))),
)
self.chunk_size = chunk_size
def __len__(self):
return self._len
def __iter__(self):
return self
def __next__(self):
chunk = []
try:
for _ in range(self.chunk_size):
chunk.append(next(self.itr))
except StopIteration as e:
if len(chunk) == 0:
raise e
return chunk
def _chunk_iterator(itr, chunk_size):
chunk = []
for x in itr:
chunk.append(x)
if len(chunk) == chunk_size:
yield chunk
chunk = []
if len(chunk) > 0:
yield chunk
class ShardedIterator(object):
class ShardedIterator(CountingIterator):
"""A sharded wrapper around an iterable, padded to length.
Args:
@ -377,30 +389,28 @@ class ShardedIterator(object):
shard_id (int): which shard to iterate over
fill_value (Any, optional): padding value when the iterable doesn't
evenly divide *num_shards* (default: None).
Attributes:
n (int): number of elements consumed from this iterator
"""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
if shard_id < 0 or shard_id >= num_shards:
raise ValueError('shard_id must be between 0 and num_shards')
self._sharded_len = len(iterable) // num_shards
if len(iterable) % num_shards > 0:
self._sharded_len += 1
self.itr = itertools.zip_longest(
range(self._sharded_len),
itertools.islice(iterable, shard_id, len(iterable), num_shards),
fillvalue=fill_value,
sharded_len = int(math.ceil(len(iterable) / float(num_shards)))
itr = map(
operator.itemgetter(1),
itertools.zip_longest(
range(sharded_len),
itertools.islice(iterable, shard_id, len(iterable), num_shards),
fillvalue=fill_value,
),
)
super().__init__(
itr,
start=int(math.ceil(getattr(iterable, 'n', 0) / float(num_shards))),
total=sharded_len,
)
def __len__(self):
return self._sharded_len
def __iter__(self):
return self
def __next__(self):
return next(self.itr)[1]
class BackgroundConsumer(Thread):

View File

@ -106,7 +106,7 @@ class BaseProgressBar(object):
"""Abstract class for progress bars."""
def __init__(self, iterable, epoch=None, prefix=None):
self.iterable = iterable
self.offset = getattr(iterable, 'offset', 0)
self.n = getattr(iterable, 'n', 0)
self.epoch = epoch
self.prefix = ''
if epoch is not None:
@ -170,7 +170,7 @@ class JsonProgressBar(BaseProgressBar):
def __iter__(self):
self.size = len(self.iterable)
for i, obj in enumerate(self.iterable, start=self.offset):
for i, obj in enumerate(self.iterable, start=self.n):
self.i = i
yield obj
@ -242,7 +242,7 @@ class SimpleProgressBar(BaseProgressBar):
def __iter__(self):
self.size = len(self.iterable)
for i, obj in enumerate(self.iterable, start=self.offset):
for i, obj in enumerate(self.iterable, start=self.n):
self.i = i
yield obj

View File

@ -10,18 +10,66 @@ from fairseq.data import iterators
class TestIterators(unittest.TestCase):
def test_counting_iterator(self):
x = list(range(10))
itr = iterators.CountingIterator(x)
def test_counting_iterator(self, ref=None, itr=None):
if ref is None:
assert itr is None
ref = list(range(10))
itr = iterators.CountingIterator(ref)
else:
assert len(ref) == 10
assert itr is not None
self.assertTrue(itr.has_next())
self.assertEqual(next(itr), 0)
self.assertEqual(next(itr), 1)
self.assertEqual(itr.n, 0)
self.assertEqual(next(itr), ref[0])
self.assertEqual(itr.n, 1)
self.assertEqual(next(itr), ref[1])
self.assertEqual(itr.n, 2)
itr.skip(3)
self.assertEqual(next(itr), 5)
self.assertEqual(itr.n, 5)
self.assertEqual(next(itr), ref[5])
itr.skip(3)
self.assertEqual(next(itr), 9)
self.assertEqual(itr.n, 9)
self.assertEqual(next(itr), ref[9])
self.assertFalse(itr.has_next())
def test_grouped_iterator(self):
    """Check GroupedIterator's chunking and its counting behaviour."""
    # test correctness: chunk size -> expected groups over range(10)
    data = list(range(10))
    expected_groups = [
        (1, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (4, [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]),
        (5, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
    ]
    for chunk_size, expected in expected_groups:
        grouped = iterators.GroupedIterator(data, chunk_size)
        self.assertEqual(list(grouped), expected)

    # test CountingIterator functionality: 30 items in chunks of 3 give the
    # 10 elements that test_counting_iterator requires of its reference
    data = list(range(30))
    reference = list(iterators.GroupedIterator(data, 3))
    grouped = iterators.GroupedIterator(data, 3)
    self.test_counting_iterator(reference, grouped)
def test_sharded_iterator(self):
    """Check ShardedIterator's sharding, padding and counting behaviour."""
    # test correctness: (num_shards, shard_id, expected items) over range(10);
    # the last two cases expect a trailing None as padding
    data = list(range(10))
    cases = [
        (1, 0, data),
        (2, 0, [0, 2, 4, 6, 8]),
        (2, 1, [1, 3, 5, 7, 9]),
        (3, 0, [0, 3, 6, 9]),
        (3, 1, [1, 4, 7, None]),
        (3, 2, [2, 5, 8, None]),
    ]
    for num_shards, shard_id, expected in cases:
        sharded = iterators.ShardedIterator(
            data, num_shards=num_shards, shard_id=shard_id,
        )
        self.assertEqual(list(sharded), expected)

    # test CountingIterator functionality: shard 0 of 3 over 30 items gives
    # the 10 elements that test_counting_iterator requires of its reference
    data = list(range(30))
    reference = list(iterators.ShardedIterator(data, num_shards=3, shard_id=0))
    sharded = iterators.ShardedIterator(data, num_shards=3, shard_id=0)
    self.test_counting_iterator(reference, sharded)
if __name__ == '__main__':
unittest.main()