Fix length mismatch in CountingIterator when truncated

Summary:
PySpeech integration training tests have recently been getting stuck at the end of an epoch.

Digging into it, it looks like this is because the end of epoch check relies on this (https://fburl.com/diffusion/xt09z6n9):

```
def end_of_epoch(self) -> bool:
     """Returns whether the most recent epoch iterator has been exhausted"""
     return not self._cur_epoch_itr.has_next()
```

which is implemented like this in CountingIterator:

    def has_next(self):
        """Whether the iterator has been exhausted."""
        return self.n < len(self)

It seems like D23172408 (110f9f0cc7) modified CountingIterator such that `len(self) > len(iter(self))` when `take()` is used. This mismatch causes `has_next` to return `True` for some PySpeech processes even when all elements in `iter(self)` have been consumed, causing training to get stuck.

My proposed fix is to remove the `self.early_stop` variable and instead directly modify `self.total` and `self.iterable`, ensuring `len(self) == len(iter(self))`.

Reviewed By: myleott

Differential Revision: D23250734

fbshipit-source-id: efb5a38216783bded67f501135b2f68b9246b9dd
This commit is contained in:
Alex Xiao 2020-08-20 20:07:45 -07:00 committed by Facebook GitHub Bot
parent 83d701ac10
commit 49940c8d25
2 changed files with 18 additions and 5 deletions

View File

@ -53,8 +53,6 @@ class CountingIterator(object):
else:
self.total = total
self.early_stop = self.total
def __len__(self):
return self.total
@ -65,8 +63,6 @@ class CountingIterator(object):
'Mismatch between actual and expected iterable length. '
'Please report this to the fairseq developers.'
)
elif self.n >= self.early_stop:
return # early stop based on take()
self.n += 1
yield x
@ -86,11 +82,13 @@ class CountingIterator(object):
"""
Truncates the iterator to n elements at most.
"""
self.early_stop = min(self.early_stop, n)
self.total = min(self.total, n)
# Propagate this change to the underlying iterator
if hasattr(self.iterable, "take"):
self.iterable.take(n)
else:
self.iterable = itertools.islice(self.iterable, n)
class EpochBatchIterating(object):

View File

@ -70,6 +70,21 @@ class TestIterators(unittest.TestCase):
itr = iterators.ShardedIterator(x, num_shards=3, shard_id=0)
self.test_counting_iterator(ref, itr)
def test_counting_iterator_take(self):
ref = list(range(10))
itr = iterators.CountingIterator(ref)
itr.take(5)
self.assertEqual(len(itr), len(list(iter(itr))))
self.assertEqual(len(itr), 5)
itr = iterators.CountingIterator(ref)
itr.take(5)
self.assertEqual(next(itr), ref[0])
self.assertEqual(next(itr), ref[1])
itr.skip(2)
self.assertEqual(next(itr), ref[4])
self.assertFalse(itr.has_next())
if __name__ == '__main__':
unittest.main()