turn persistent workers off by default (#4524)

Alexei Baevski, 2022-06-30 11:48:49 -07:00 (committed by GitHub)
parent ba415c99ca · commit 5307a0e078
2 changed files with 19 additions and 8 deletions
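
This change replaces the hard-coded persistent_workers=self.num_workers > 0 that was previously passed to torch.utils.data.DataLoader with an explicit persistent_workers flag on the iterator constructors, defaulting to False. In plain PyTorch terms, the flag controls whether worker processes survive across epochs; a minimal toy sketch (not fairseq code):

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return torch.tensor(idx)

    # persistent_workers=True keeps the worker processes alive between
    # epochs (saving re-spawn overhead but holding their memory); it is
    # only valid when num_workers > 0.
    loader = DataLoader(ToyDataset(), batch_size=2, num_workers=2,
                        persistent_workers=True)

    for epoch in range(2):
        for batch in loader:  # the same workers serve the second epoch
            pass

With persistent_workers=False (the new default), the DataLoader tears its workers down at the end of every epoch, so worker memory is released between epochs at the cost of re-spawning them.

The first file changed is the iterator module, fairseq/data/iterators.py: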

@@ -156,6 +156,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
         num_workers=0,
         buffer_size=0,
         timeout=0,
+        persistent_workers=False,
     ):
         assert isinstance(dataset, torch.utils.data.IterableDataset)
         self.dataset = dataset
@@ -167,6 +168,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
         # in a shared computing environment.
         self.buffer_size = min(buffer_size, 20)
         self.timeout = timeout
+        self.persistent_workers = persistent_workers
 
         self._current_epoch_iterator = None
@@ -218,7 +220,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
             timeout=self.timeout,
             worker_init_fn=worker_init_fn,
             pin_memory=True,
-            persistent_workers=self.num_workers > 0,
+            persistent_workers=self.persistent_workers,
         )
 
         # Wrap with a BufferedIterator if needed
@@ -319,6 +321,7 @@ class EpochBatchIterator(EpochBatchIterating):
         skip_remainder_batch=False,
         grouped_shuffling=False,
         reuse_dataloader=False,
+        persistent_workers=False,
     ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
@@ -347,6 +350,7 @@ class EpochBatchIterator(EpochBatchIterating):
 
         self.dataloader = None
         self.reuse_dataloader = reuse_dataloader
+        self.persistent_workers = persistent_workers
 
     @property
     def frozen_batches(self):
@@ -478,10 +482,10 @@ class EpochBatchIterator(EpochBatchIterating):
         self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
    ):
         if self.reuse_dataloader and self.dataloader is not None:
-            self.batch_sampler.make_batches_for_epoch(epoch, offset)
+            self.epoch_batch_sampler.make_batches_for_epoch(epoch, offset)
             itr = self.dataloader
         else:
-            self.batch_sampler = FrozenBatchSampler(
+            self.epoch_batch_sampler = FrozenBatchSampler(
                 self.ordered_batches,
                 epoch,
                 fix_batches_to_gpus,
@@ -489,7 +493,7 @@ class EpochBatchIterator(EpochBatchIterating):
                 initial_offset=offset,
             )
 
-            if offset > 0 and len(self.batch_sampler) == 0:
+            if offset > 0 and len(self.epoch_batch_sampler) == 0:
                 return None
 
             if self.num_workers > 0:
@@ -499,11 +503,11 @@ class EpochBatchIterator(EpochBatchIterating):
             itr = torch.utils.data.DataLoader(
                 self.dataset,
                 collate_fn=self.collate_fn,
-                batch_sampler=self.batch_sampler,
+                batch_sampler=self.epoch_batch_sampler,
                 num_workers=self.num_workers,
                 timeout=self.timeout,
                 pin_memory=True,
-                persistent_workers=self.num_workers > 0,
+                persistent_workers=self.persistent_workers,
             )
 
             if self.reuse_dataloader:
@@ -519,7 +523,8 @@ class EpochBatchIterator(EpochBatchIterating):
         if self.skip_remainder_batch:
             # TODO: Below is a lazy implementation which discard the final batch regardless
             # of whether it is a full batch or not.
-            total_num_itrs = len(self.batch_sampler) - 1
+            total_num_itrs = len(self.epoch_batch_sampler) - 1
             itr.take(total_num_itrs)
             logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}")
@@ -772,6 +777,8 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
         mult_rate=1,
         buffer_size=0,
         skip_remainder_batch=False,
+        reuse_dataloader=False,
+        persistent_workers=False,
     ):
         super().__init__(
             dataset,
@@ -784,6 +791,8 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
             epoch,
             buffer_size,
             skip_remainder_batch=skip_remainder_batch,
+            reuse_dataloader=reuse_dataloader,
+            persistent_workers=persistent_workers,
         )
         # level 0: sub-samplers 1: batch_idx 2: batches
         self._frozen_batches = tuple([tuple(sub_batch) for sub_batch in batch_samplers])
@@ -866,7 +875,7 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
             collate_fn=self.collate_fn,
             batch_sampler=batches[offset:],
             num_workers=self.num_workers,
-            persistent_workers=self.num_workers > 0,
+            persistent_workers=self.persistent_workers,
         )
         if self.buffer_size > 0:
             itr = BufferedIterator(self.buffer_size, itr)
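
With these changes, persistent workers are opt-in on all three iterator classes. A hedged construction sketch using only keyword arguments visible in this diff (dataset, collate_fn, and batch_sampler are placeholders the caller must supply; other constructor arguments are elided):

    from fairseq.data import iterators

    # Sketch only: most constructor arguments are omitted here.
    epoch_iter = iterators.EpochBatchIterator(
        dataset=dataset,              # a torch.utils.data.Dataset
        collate_fn=dataset.collater,  # placeholder collate function
        batch_sampler=batches,        # placeholder batch sampler
        num_workers=4,
        reuse_dataloader=True,        # keep the DataLoader object across epochs
        persistent_workers=True,      # opt back in to long-lived workers
    )

The second file threads the flag through FairseqTask.get_batch_iterator in fairseq/tasks/fairseq_task.py, reading it off the task config: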

@@ -300,6 +300,7 @@ class FairseqTask(object):
         )
 
         reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True)
+        persistent_workers = getattr(self.cfg, "persistent_workers", False)
 
         # return a reusable, sharded iterator
         epoch_iter = iterators.EpochBatchIterator(
@@ -315,6 +316,7 @@ class FairseqTask(object):
             skip_remainder_batch=skip_remainder_batch,
             grouped_shuffling=grouped_shuffling,
             reuse_dataloader=reuse_dataloader,
+            persistent_workers=persistent_workers,
         )
 
         if can_reuse_epoch_itr:
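
Because the task reads the flag with a getattr fallback, configs that predate this change (or never set the attribute) silently pick up the new default. A small sketch of that lookup, with SimpleNamespace standing in for the task config:

    from types import SimpleNamespace

    cfg = SimpleNamespace(reuse_dataloader=True)  # persistent_workers never set

    # Mirrors the task-side lookup above: a missing attribute resolves to
    # False, so worker processes are torn down at the end of each epoch.
    persistent_workers = getattr(cfg, "persistent_workers", False)
    assert persistent_workers is False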