mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-08-16 12:00:25 +03:00
turn persistent workers off by default (#4524)
This commit is contained in:
parent
ba415c99ca
commit
5307a0e078
@ -156,6 +156,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
|
||||
num_workers=0,
|
||||
buffer_size=0,
|
||||
timeout=0,
|
||||
persistent_workers=False,
|
||||
):
|
||||
assert isinstance(dataset, torch.utils.data.IterableDataset)
|
||||
self.dataset = dataset
|
||||
@ -167,6 +168,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
|
||||
# in a shared computing environment.
|
||||
self.buffer_size = min(buffer_size, 20)
|
||||
self.timeout = timeout
|
||||
self.persistent_workers = persistent_workers
|
||||
|
||||
self._current_epoch_iterator = None
|
||||
|
||||
@ -218,7 +220,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
|
||||
timeout=self.timeout,
|
||||
worker_init_fn=worker_init_fn,
|
||||
pin_memory=True,
|
||||
persistent_workers=self.num_workers > 0,
|
||||
persistent_workers=self.persistent_workers,
|
||||
)
|
||||
|
||||
# Wrap with a BufferedIterator if needed
|
||||
@ -319,6 +321,7 @@ class EpochBatchIterator(EpochBatchIterating):
|
||||
skip_remainder_batch=False,
|
||||
grouped_shuffling=False,
|
||||
reuse_dataloader=False,
|
||||
persistent_workers=False,
|
||||
):
|
||||
assert isinstance(dataset, torch.utils.data.Dataset)
|
||||
self.dataset = dataset
|
||||
@ -347,6 +350,7 @@ class EpochBatchIterator(EpochBatchIterating):
|
||||
|
||||
self.dataloader = None
|
||||
self.reuse_dataloader = reuse_dataloader
|
||||
self.persistent_workers = persistent_workers
|
||||
|
||||
@property
|
||||
def frozen_batches(self):
|
||||
@ -478,10 +482,10 @@ class EpochBatchIterator(EpochBatchIterating):
|
||||
self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
|
||||
):
|
||||
if self.reuse_dataloader and self.dataloader is not None:
|
||||
self.batch_sampler.make_batches_for_epoch(epoch, offset)
|
||||
self.epoch_batch_sampler.make_batches_for_epoch(epoch, offset)
|
||||
itr = self.dataloader
|
||||
else:
|
||||
self.batch_sampler = FrozenBatchSampler(
|
||||
self.epoch_batch_sampler = FrozenBatchSampler(
|
||||
self.ordered_batches,
|
||||
epoch,
|
||||
fix_batches_to_gpus,
|
||||
@ -489,7 +493,7 @@ class EpochBatchIterator(EpochBatchIterating):
|
||||
initial_offset=offset,
|
||||
)
|
||||
|
||||
if offset > 0 and len(self.batch_sampler) == 0:
|
||||
if offset > 0 and len(self.epoch_batch_sampler) == 0:
|
||||
return None
|
||||
|
||||
if self.num_workers > 0:
|
||||
@ -499,11 +503,11 @@ class EpochBatchIterator(EpochBatchIterating):
|
||||
itr = torch.utils.data.DataLoader(
|
||||
self.dataset,
|
||||
collate_fn=self.collate_fn,
|
||||
batch_sampler=self.batch_sampler,
|
||||
batch_sampler=self.epoch_batch_sampler,
|
||||
num_workers=self.num_workers,
|
||||
timeout=self.timeout,
|
||||
pin_memory=True,
|
||||
persistent_workers=self.num_workers > 0,
|
||||
persistent_workers=self.persistent_workers,
|
||||
)
|
||||
|
||||
if self.reuse_dataloader:
|
||||
@ -519,7 +523,8 @@ class EpochBatchIterator(EpochBatchIterating):
|
||||
if self.skip_remainder_batch:
|
||||
# TODO: Below is a lazy implementation which discards the final batch regardless
|
||||
# of whether it is a full batch or not.
|
||||
total_num_itrs = len(self.batch_sampler) - 1
|
||||
|
||||
total_num_itrs = len(self.epoch_batch_sampler) - 1
|
||||
itr.take(total_num_itrs)
|
||||
logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}")
|
||||
|
||||
@ -772,6 +777,8 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
|
||||
mult_rate=1,
|
||||
buffer_size=0,
|
||||
skip_remainder_batch=False,
|
||||
reuse_dataloader=False,
|
||||
persistent_workers=False,
|
||||
):
|
||||
super().__init__(
|
||||
dataset,
|
||||
@ -784,6 +791,8 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
|
||||
epoch,
|
||||
buffer_size,
|
||||
skip_remainder_batch=skip_remainder_batch,
|
||||
reuse_dataloader=reuse_dataloader,
|
||||
persistent_workers=persistent_workers,
|
||||
)
|
||||
# level 0: sub-samplers 1: batch_idx 2: batches
|
||||
self._frozen_batches = tuple([tuple(sub_batch) for sub_batch in batch_samplers])
|
||||
@ -866,7 +875,7 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
|
||||
collate_fn=self.collate_fn,
|
||||
batch_sampler=batches[offset:],
|
||||
num_workers=self.num_workers,
|
||||
persistent_workers=self.num_workers > 0,
|
||||
persistent_workers=self.persistent_workers,
|
||||
)
|
||||
if self.buffer_size > 0:
|
||||
itr = BufferedIterator(self.buffer_size, itr)
|
||||
|
@ -300,6 +300,7 @@ class FairseqTask(object):
|
||||
)
|
||||
|
||||
reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True)
|
||||
persistent_workers = getattr(self.cfg, "persistent_workers", False)
|
||||
|
||||
# return a reusable, sharded iterator
|
||||
epoch_iter = iterators.EpochBatchIterator(
|
||||
@ -315,6 +316,7 @@ class FairseqTask(object):
|
||||
skip_remainder_batch=skip_remainder_batch,
|
||||
grouped_shuffling=grouped_shuffling,
|
||||
reuse_dataloader=reuse_dataloader,
|
||||
persistent_workers=persistent_workers,
|
||||
)
|
||||
|
||||
if can_reuse_epoch_itr:
|
||||
|
Loading…
Reference in New Issue
Block a user