turn persistent workers off by default (#4524)

This commit is contained in:
Alexei Baevski 2022-06-30 11:48:49 -07:00 committed by GitHub
parent ba415c99ca
commit 5307a0e078
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 8 deletions

View File

@@ -156,6 +156,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
         num_workers=0,
         buffer_size=0,
         timeout=0,
+        persistent_workers=False,
     ):
         assert isinstance(dataset, torch.utils.data.IterableDataset)
         self.dataset = dataset
@@ -167,6 +168,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
         # in a shared computing environment.
         self.buffer_size = min(buffer_size, 20)
         self.timeout = timeout
+        self.persistent_workers = persistent_workers
        self._current_epoch_iterator = None
@@ -218,7 +220,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
             timeout=self.timeout,
             worker_init_fn=worker_init_fn,
             pin_memory=True,
-            persistent_workers=self.num_workers > 0,
+            persistent_workers=self.persistent_workers,
         )
         # Wrap with a BufferedIterator if needed
@@ -319,6 +321,7 @@ class EpochBatchIterator(EpochBatchIterating):
         skip_remainder_batch=False,
         grouped_shuffling=False,
         reuse_dataloader=False,
+        persistent_workers=False,
     ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
@@ -347,6 +350,7 @@ class EpochBatchIterator(EpochBatchIterating):
        self.dataloader = None
        self.reuse_dataloader = reuse_dataloader
+        self.persistent_workers = persistent_workers

     @property
     def frozen_batches(self):
@@ -478,10 +482,10 @@ class EpochBatchIterator(EpochBatchIterating):
        self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
    ):
        if self.reuse_dataloader and self.dataloader is not None:
-            self.batch_sampler.make_batches_for_epoch(epoch, offset)
+            self.epoch_batch_sampler.make_batches_for_epoch(epoch, offset)
            itr = self.dataloader
        else:
-            self.batch_sampler = FrozenBatchSampler(
+            self.epoch_batch_sampler = FrozenBatchSampler(
                self.ordered_batches,
                epoch,
                fix_batches_to_gpus,
@@ -489,7 +493,7 @@ class EpochBatchIterator(EpochBatchIterating):
                initial_offset=offset,
            )
-            if offset > 0 and len(self.batch_sampler) == 0:
+            if offset > 0 and len(self.epoch_batch_sampler) == 0:
                return None

            if self.num_workers > 0:
@@ -499,11 +503,11 @@ class EpochBatchIterator(EpochBatchIterating):
            itr = torch.utils.data.DataLoader(
                self.dataset,
                collate_fn=self.collate_fn,
-                batch_sampler=self.batch_sampler,
+                batch_sampler=self.epoch_batch_sampler,
                num_workers=self.num_workers,
                timeout=self.timeout,
                pin_memory=True,
-                persistent_workers=self.num_workers > 0,
+                persistent_workers=self.persistent_workers,
            )

            if self.reuse_dataloader:
@@ -519,7 +523,8 @@ class EpochBatchIterator(EpochBatchIterating):
        if self.skip_remainder_batch:
            # TODO: Below is a lazy implementation which discard the final batch regardless
            # of whether it is a full batch or not.
-            total_num_itrs = len(self.batch_sampler) - 1
+            total_num_itrs = len(self.epoch_batch_sampler) - 1
            itr.take(total_num_itrs)
            logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}")
@@ -772,6 +777,8 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
        mult_rate=1,
        buffer_size=0,
        skip_remainder_batch=False,
+        reuse_dataloader=False,
+        persistent_workers=False,
    ):
        super().__init__(
            dataset,
@@ -784,6 +791,8 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
            epoch,
            buffer_size,
            skip_remainder_batch=skip_remainder_batch,
+            reuse_dataloader=reuse_dataloader,
+            persistent_workers=persistent_workers,
        )
        # level 0: sub-samplers 1: batch_idx 2: batches
        self._frozen_batches = tuple([tuple(sub_batch) for sub_batch in batch_samplers])
@@ -866,7 +875,7 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
                collate_fn=self.collate_fn,
                batch_sampler=batches[offset:],
                num_workers=self.num_workers,
-                persistent_workers=self.num_workers > 0,
+                persistent_workers=self.persistent_workers,
            )
            if self.buffer_size > 0:
                itr = BufferedIterator(self.buffer_size, itr)

View File

@@ -300,6 +300,7 @@ class FairseqTask(object):
        )
        reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True)
+        persistent_workers = getattr(self.cfg, "persistent_workers", False)

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
@@ -315,6 +316,7 @@ class FairseqTask(object):
            skip_remainder_batch=skip_remainder_batch,
            grouped_shuffling=grouped_shuffling,
            reuse_dataloader=reuse_dataloader,
+            persistent_workers=persistent_workers,
        )

        if can_reuse_epoch_itr: