diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py
index 24e7475b3..45a8c65fa 100644
--- a/fairseq/data/iterators.py
+++ b/fairseq/data/iterators.py
@@ -156,6 +156,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
         num_workers=0,
         buffer_size=0,
         timeout=0,
+        persistent_workers=False,
     ):
         assert isinstance(dataset, torch.utils.data.IterableDataset)
         self.dataset = dataset
@@ -167,6 +168,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
         # in a shared computing environment.
         self.buffer_size = min(buffer_size, 20)
         self.timeout = timeout
+        self.persistent_workers = persistent_workers
 
         self._current_epoch_iterator = None
 
@@ -218,7 +220,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
             timeout=self.timeout,
             worker_init_fn=worker_init_fn,
             pin_memory=True,
-            persistent_workers=self.num_workers > 0,
+            persistent_workers=self.persistent_workers,
         )
 
         # Wrap with a BufferedIterator if needed
@@ -319,6 +321,7 @@ class EpochBatchIterator(EpochBatchIterating):
         skip_remainder_batch=False,
         grouped_shuffling=False,
         reuse_dataloader=False,
+        persistent_workers=False,
     ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
@@ -347,6 +350,7 @@ class EpochBatchIterator(EpochBatchIterating):
         self.dataloader = None
 
         self.reuse_dataloader = reuse_dataloader
+        self.persistent_workers = persistent_workers
 
     @property
     def frozen_batches(self):
@@ -478,10 +482,10 @@ class EpochBatchIterator(EpochBatchIterating):
         self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
     ):
         if self.reuse_dataloader and self.dataloader is not None:
-            self.batch_sampler.make_batches_for_epoch(epoch, offset)
+            self.epoch_batch_sampler.make_batches_for_epoch(epoch, offset)
             itr = self.dataloader
         else:
-            self.batch_sampler = FrozenBatchSampler(
+            self.epoch_batch_sampler = FrozenBatchSampler(
                 self.ordered_batches,
                 epoch,
                 fix_batches_to_gpus,
@@ -489,7 +493,7 @@ class EpochBatchIterator(EpochBatchIterating):
                 initial_offset=offset,
             )
 
-            if offset > 0 and len(self.batch_sampler) == 0:
+            if offset > 0 and len(self.epoch_batch_sampler) == 0:
                 return None
 
             if self.num_workers > 0:
@@ -499,11 +503,11 @@ class EpochBatchIterator(EpochBatchIterating):
            itr = torch.utils.data.DataLoader(
                 self.dataset,
                 collate_fn=self.collate_fn,
-                batch_sampler=self.batch_sampler,
+                batch_sampler=self.epoch_batch_sampler,
                 num_workers=self.num_workers,
                 timeout=self.timeout,
                 pin_memory=True,
-                persistent_workers=self.num_workers > 0,
+                persistent_workers=self.persistent_workers,
             )
 
             if self.reuse_dataloader:
@@ -519,7 +523,8 @@ class EpochBatchIterator(EpochBatchIterating):
         if self.skip_remainder_batch:
             # TODO: Below is a lazy implementation which discard the final batch regardless
             # of whether it is a full batch or not.
-            total_num_itrs = len(self.batch_sampler) - 1
+
+            total_num_itrs = len(self.epoch_batch_sampler) - 1
             itr.take(total_num_itrs)
             logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}")
 
@@ -772,6 +777,8 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
         mult_rate=1,
         buffer_size=0,
         skip_remainder_batch=False,
+        reuse_dataloader=False,
+        persistent_workers=False,
     ):
         super().__init__(
             dataset,
@@ -784,6 +791,8 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
             epoch,
             buffer_size,
             skip_remainder_batch=skip_remainder_batch,
+            reuse_dataloader=reuse_dataloader,
+            persistent_workers=persistent_workers,
         )
         # level 0: sub-samplers 1: batch_idx 2: batches
         self._frozen_batches = tuple([tuple(sub_batch) for sub_batch in batch_samplers])
@@ -866,7 +875,7 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
             collate_fn=self.collate_fn,
             batch_sampler=batches[offset:],
             num_workers=self.num_workers,
-            persistent_workers=self.num_workers > 0,
+            persistent_workers=self.persistent_workers,
         )
         if self.buffer_size > 0:
             itr = BufferedIterator(self.buffer_size, itr)
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 0d1dcf690..3cba8f224 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -300,6 +300,7 @@ class FairseqTask(object):
         )
 
         reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True)
+        persistent_workers = getattr(self.cfg, "persistent_workers", False)
 
         # return a reusable, sharded iterator
         epoch_iter = iterators.EpochBatchIterator(
@@ -315,6 +316,7 @@ class FairseqTask(object):
             skip_remainder_batch=skip_remainder_batch,
             grouped_shuffling=grouped_shuffling,
             reuse_dataloader=reuse_dataloader,
+            persistent_workers=persistent_workers,
        )

         if can_reuse_epoch_itr:
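
Usage sketch (assumes a fairseq checkout that includes this patch): the new persistent_workers flag is forwarded straight to torch.utils.data.DataLoader, which requires num_workers > 0 when the flag is set. The toy dataset, index batches, and collate function below are illustrative placeholders, not part of the change.

import torch

from fairseq.data import iterators


class ToyDataset(torch.utils.data.Dataset):
    """Stand-in map-style dataset; any fairseq dataset is passed the same way."""

    def __getitem__(self, index):
        return torch.tensor([index])

    def __len__(self):
        return 8


if __name__ == "__main__":
    # Pre-computed batches of dataset indices, as EpochBatchIterator expects.
    batches = [[0, 1], [2, 3], [4, 5], [6, 7]]

    epoch_iter = iterators.EpochBatchIterator(
        dataset=ToyDataset(),
        collate_fn=torch.stack,    # merge a list of samples into one tensor
        batch_sampler=batches,
        num_workers=2,             # DataLoader needs num_workers > 0 for persistent workers
        reuse_dataloader=True,     # keep the DataLoader object between epochs
        persistent_workers=True,   # ...and keep its worker processes alive as well
    )

    for batch in epoch_iter.next_epoch_itr(shuffle=False):
        print(batch.shape)  # torch.Size([2, 1])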