mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
fix mismatch length of counting iterator when truncated
Summary:
PySpeech integration training tests have recently been stuck at end of epoch.
Digging into it, it looks like this is because the end of epoch check relies on this (https://fburl.com/diffusion/xt09z6n9):
```
def end_of_epoch(self) -> bool:
"""Returns whether the most recent epoch iterator has been exhausted"""
return not self._cur_epoch_itr.has_next()
```
which is implemented like this in CountingIterator:
def has_next(self):
"""Whether the iterator has been exhausted."""
return self.n < len(self)
It seems like D23172408 (110f9f0cc7
) modified CountingIterator such that `len(self) > len(iter(self))` when `take()` is used. This mismatch causes `has_next` to return `True` for some PySpeech processes even when all elements in `iter(self))` have been consumed, causing training to get stuck.
My proposed fix is to remove the `self.early_stop` variable and just directly modify `self.total` and `self.iterable`, ensuring `len(self) == len(iter(self))`
Reviewed By: myleott
Differential Revision: D23250734
fbshipit-source-id: efb5a38216783bded67f501135b2f68b9246b9dd
This commit is contained in:
parent
83d701ac10
commit
49940c8d25
@ -53,8 +53,6 @@ class CountingIterator(object):
|
||||
else:
|
||||
self.total = total
|
||||
|
||||
self.early_stop = self.total
|
||||
|
||||
def __len__(self):
|
||||
return self.total
|
||||
|
||||
@ -65,8 +63,6 @@ class CountingIterator(object):
|
||||
'Mismatch between actual and expected iterable length. '
|
||||
'Please report this to the fairseq developers.'
|
||||
)
|
||||
elif self.n >= self.early_stop:
|
||||
return # early stop based on take()
|
||||
self.n += 1
|
||||
yield x
|
||||
|
||||
@ -86,11 +82,13 @@ class CountingIterator(object):
|
||||
"""
|
||||
Truncates the iterator to n elements at most.
|
||||
"""
|
||||
self.early_stop = min(self.early_stop, n)
|
||||
self.total = min(self.total, n)
|
||||
|
||||
# Propagate this change to the underlying iterator
|
||||
if hasattr(self.iterable, "take"):
|
||||
self.iterable.take(n)
|
||||
else:
|
||||
self.iterable = itertools.islice(self.iterable, n)
|
||||
|
||||
|
||||
class EpochBatchIterating(object):
|
||||
|
@ -70,6 +70,21 @@ class TestIterators(unittest.TestCase):
|
||||
itr = iterators.ShardedIterator(x, num_shards=3, shard_id=0)
|
||||
self.test_counting_iterator(ref, itr)
|
||||
|
||||
def test_counting_iterator_take(self):
|
||||
ref = list(range(10))
|
||||
itr = iterators.CountingIterator(ref)
|
||||
itr.take(5)
|
||||
self.assertEqual(len(itr), len(list(iter(itr))))
|
||||
self.assertEqual(len(itr), 5)
|
||||
|
||||
itr = iterators.CountingIterator(ref)
|
||||
itr.take(5)
|
||||
self.assertEqual(next(itr), ref[0])
|
||||
self.assertEqual(next(itr), ref[1])
|
||||
itr.skip(2)
|
||||
self.assertEqual(next(itr), ref[4])
|
||||
self.assertFalse(itr.has_next())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user