fixed train valid epoch iter

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/866

Differential Revision: D17517115

fbshipit-source-id: fd6921e642c99e37fce6ad58b24c93e70a5364e5
Naman Goyal 2019-09-23 11:32:20 -07:00 committed by Facebook Github Bot
parent 10f9349e8a
commit 3b09b98b66
2 changed files with 7 additions and 7 deletions


@@ -24,7 +24,7 @@ class FairseqTask(object):
     def __init__(self, args):
         self.args = args
         self.datasets = {}
-        self.epoch_iter = None
+        self.dataset_to_epoch_iter = {}
     @classmethod
     def load_dictionary(cls, filename):
@@ -120,7 +120,6 @@ class FairseqTask(object):
                 (default: 0).
             epoch (int, optional): the epoch to start the iterator from
                 (default: 0).
         Returns:
             ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                 given dataset split
@@ -128,8 +127,8 @@ class FairseqTask(object):
         # For default fairseq task, return same iterator across epochs
         # as datasets are not dynamic, can be overridden in task specific
         # setting.
-        if self.epoch_iter is not None:
-            return self.epoch_iter
+        if dataset in self.dataset_to_epoch_iter:
+            return self.dataset_to_epoch_iter[dataset]
         assert isinstance(dataset, FairseqDataset)
@@ -153,7 +152,7 @@ class FairseqTask(object):
         )
         # return a reusable, sharded iterator
-        self.epoch_iter = iterators.EpochBatchIterator(
+        epoch_iter = iterators.EpochBatchIterator(
             dataset=dataset,
             collate_fn=dataset.collater,
             batch_sampler=batch_sampler,
@@ -163,7 +162,8 @@ class FairseqTask(object):
             num_workers=num_workers,
             epoch=epoch,
         )
-        return self.epoch_iter
+        self.dataset_to_epoch_iter[dataset] = epoch_iter
+        return epoch_iter
     def build_model(self, args):
         """


@@ -286,7 +286,7 @@ class MultiLingualMaskedLMTask(FairseqTask):
     ):
         # Recreate epoch iterator every epoch cause the underlying
         # datasets are dynamic due to sampling.
-        self.epoch_iter = None
+        self.dataset_to_epoch_iter = None
         return super().get_batch_iterator(
             dataset, max_tokens, max_sentences, max_positions,
             ignore_invalid_inputs, required_batch_size_multiple,
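
MultiLingualMaskedLMTask opts out of the cache because its datasets are re-sampled every epoch, so it resets dataset_to_epoch_iter before delegating to the base implementation. Below is a minimal sketch of that override pattern with hypothetical stand-ins (CachingTask, ResamplingTask), not fairseq's API; the sketch resets the cache to an empty dict, which keeps the membership test in the base implementation valid:

import itertools

class CachingTask:
    # Base behaviour: build an iterator once per dataset and reuse it, as in FairseqTask above.
    def __init__(self):
        self.dataset_to_epoch_iter = {}
        self._ids = itertools.count()

    def get_batch_iterator(self, dataset):
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]
        epoch_iter = ("iterator", next(self._ids), dataset)  # stand-in for EpochBatchIterator
        self.dataset_to_epoch_iter[dataset] = epoch_iter
        return epoch_iter

class ResamplingTask(CachingTask):
    # Dynamic datasets: drop the cache on every call so each epoch rebuilds its iterator.
    def get_batch_iterator(self, dataset):
        self.dataset_to_epoch_iter = {}
        return super().get_batch_iterator(dataset)

static_task = CachingTask()
assert static_task.get_batch_iterator("train") is static_task.get_batch_iterator("train")

dynamic_task = ResamplingTask()
assert dynamic_task.get_batch_iterator("train") != dynamic_task.get_batch_iterator("train")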