Use 1-based indexing for epochs everywhere (#1053)

Summary:
We are somewhat inconsistent in whether we're using 0-based or 1-based indexing for epochs. This should fix things to be 1-based internally, with logging and checkpoint naming also using 1-based indexing.
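The practical consequence of the convention shows up in the `load_dataset` implementations below: with 1-based epochs, round-robin selection over multiple data paths subtracts one before taking the modulus. A minimal sketch of that pattern (the helper name `pick_data_path` is illustrative, not fairseq API):

```python
def pick_data_path(paths, epoch):
    """Select the data directory for a 1-based epoch number.

    Epoch 1 maps to paths[0], epoch 2 to paths[1], and so on,
    wrapping around when there are more epochs than paths.
    """
    assert epoch >= 1, "epochs are 1-based"
    return paths[(epoch - 1) % len(paths)]


# with three shards, epochs 1-4 cycle shard1, shard2, shard3, shard1
paths = ["data/shard1", "data/shard2", "data/shard3"]
assert [pick_data_path(paths, e) for e in (1, 2, 3, 4)] == [
    "data/shard1", "data/shard2", "data/shard3", "data/shard1",
]
```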
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1053

Reviewed By: spencerp

Differential Revision: D20160715

Pulled By: myleott

fbshipit-source-id: 4ed94f9c371e1bfe29bcfa087fa6756507d6e627
Myle Ott 2020-03-04 16:34:53 -08:00 committed by Facebook Github Bot
parent 4171b83cfd
commit aa79bb9c37
26 changed files with 63 additions and 116 deletions

View File

@ -66,7 +66,7 @@ class CommonsenseQATask(FairseqTask):
return cls(args, vocab)
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
"""Load a given dataset split.
Args:

View File

@ -101,7 +101,7 @@ class WSCTask(FairseqTask):
mask[mask_start:mask_start + mask_size] = 1
return toks, mask
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
"""Load a given dataset split.
Args:
@ -281,7 +281,7 @@ class WinograndeTask(WSCTask):
return cls(args, vocab)
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
"""Load a given dataset split.
Args:

View File

@ -42,7 +42,7 @@ class DummyLMTask(FairseqTask):
return cls(args, dictionary)
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)

View File

@ -53,7 +53,7 @@ class DummyMaskedLMTask(FairseqTask):
return cls(args, dictionary)
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)

View File

@ -109,6 +109,7 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
if os.path.lexists(old_chk):
os.remove(old_chk)
def load_checkpoint(args, trainer, **passthrough_args):
"""
Load a checkpoint and restore the training iterator.
@ -150,7 +151,7 @@ def load_checkpoint(args, trainer, **passthrough_args):
epoch_itr.load_state_dict(itr_state)
else:
epoch_itr = trainer.get_train_iterator(
epoch=0, load_dataset=True, **passthrough_args
epoch=1, load_dataset=True, **passthrough_args
)
trainer.lr_step(epoch_itr.epoch)
@ -349,6 +350,11 @@ def _upgrade_state_dict(state):
state["args"].dataset_impl = "raw"
elif getattr(state["args"], "lazy_load", False):
state["args"].dataset_impl = "lazy"
# epochs start at 1
state["extra_state"]["train_iterator"]["epoch"] = max(
getattr(state["extra_state"]["train_iterator"], "epoch", 1),
1,
)
# set any missing default values in the task, model or other registries
registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task])
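The `_upgrade_state_dict` hunk above clamps the stored epoch so checkpoints written under the old 0-based scheme resume at a valid 1-based epoch. A simplified sketch of the intended effect on the `train_iterator` state (structure reduced to the relevant key; not the actual upgrade function):

```python
def upgrade_epoch(extra_state):
    """Clamp a checkpoint's stored epoch to the new 1-based scheme."""
    itr_state = extra_state.setdefault("train_iterator", {})
    # old checkpoints may carry epoch 0 (or no epoch at all)
    itr_state["epoch"] = max(itr_state.get("epoch", 1), 1)
    return extra_state


assert upgrade_epoch({"train_iterator": {"epoch": 0}})["train_iterator"]["epoch"] == 1
assert upgrade_epoch({"train_iterator": {"epoch": 7}})["train_iterator"]["epoch"] == 7
```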

View File

@ -38,7 +38,6 @@ from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sharded_dataset import ShardedDataset
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .subsample_dataset import SubsampleDataset
@ -96,7 +95,6 @@ __all__ = [
'ResamplingDataset',
'RightPadDataset',
'RoundRobinZipDatasets',
'ShardedDataset',
'ShardedIterator',
'SortDataset',
'StripTokenDataset',

View File

@ -100,17 +100,18 @@ class EpochBatchIterating(object):
class StreamingEpochBatchIterator(EpochBatchIterating):
def __init__(
self, dataset, epoch=0, num_shards=1, shard_id=0,
self, dataset, epoch=1, num_shards=1, shard_id=0,
):
assert isinstance(dataset, torch.utils.data.IterableDataset)
self.dataset = dataset
self.epoch = epoch
self.epoch = max(epoch, 1) # we use 1-based indexing for epochs
self._current_epoch_iterator = None
self.num_shards = num_shards
self.shard_id = shard_id
def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
self.epoch += 1
if self._current_epoch_iterator is not None and self.end_of_epoch():
self.epoch += 1
self.dataset.set_epoch(self.epoch)
self._current_epoch_iterator = CountingIterator(
iterable=ShardedIterator(
@ -165,12 +166,12 @@ class EpochBatchIterator(EpochBatchIterating):
loading. 0 means the data will be loaded in the main process
(default: 0).
epoch (int, optional): the epoch to start the iterator from
(default: 0).
(default: 1).
"""
def __init__(
self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
num_workers=0, epoch=0,
num_workers=0, epoch=1,
):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
@ -181,7 +182,7 @@ class EpochBatchIterator(EpochBatchIterating):
self.shard_id = shard_id
self.num_workers = num_workers
self.epoch = epoch
self.epoch = max(epoch, 1) # we use 1-based indexing for epochs
self.shuffle = True
self._cur_epoch_itr = None
self._next_epoch_itr = None
@ -204,7 +205,8 @@ class EpochBatchIterator(EpochBatchIterating):
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
else:
self.epoch += 1
if self._cur_epoch_itr is not None and self.end_of_epoch():
self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch(
self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
)
@ -244,6 +246,9 @@ class EpochBatchIterator(EpochBatchIterating):
shuffle=state_dict.get('shuffle', True),
offset=itr_pos,
)
if self._next_epoch_itr is None:
# we finished the epoch, increment epoch counter
self.epoch += 1
def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0):
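The iterator changes above also mean the epoch counter is no longer bumped unconditionally at the top of `next_epoch_itr`; it advances only once the previous epoch's iterator has been fully consumed, so a run restored mid-epoch keeps its epoch number. A toy model of that rule (an illustrative class, not the real `EpochBatchIterator`):

```python
class ToyEpochCounter:
    """Toy model of the new increment rule: the epoch only advances
    once the previous epoch's iterator has been fully consumed."""

    def __init__(self, epoch=1):
        self.epoch = max(epoch, 1)  # epochs are 1-based
        self._cur_itr = None

    def end_of_epoch(self):
        return self._cur_itr is not None and len(self._cur_itr) == 0

    def next_epoch_itr(self, batches):
        if self._cur_itr is not None and self.end_of_epoch():
            self.epoch += 1  # previous epoch finished, so advance
        self._cur_itr = list(batches)
        return self._cur_itr


itr = ToyEpochCounter()
itr.next_epoch_itr([1, 2, 3])
assert itr.epoch == 1          # first epoch stays at 1
itr._cur_itr.clear()           # pretend every batch was consumed
itr.next_epoch_itr([1, 2, 3])
assert itr.epoch == 2          # only now does the counter advance
```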

View File

@ -31,7 +31,7 @@ class ResamplingDataset(BaseWrapperDataset):
batch_by_size (bool): whether or not to batch by sequence length
(default: True).
seed (int): RNG seed to use (default: 0).
epoch (int): starting epoch number (default: 0).
epoch (int): starting epoch number (default: 1).
"""
def __init__(
@ -42,7 +42,7 @@ class ResamplingDataset(BaseWrapperDataset):
size_ratio=1.0,
batch_by_size=True,
seed=0,
epoch=0,
epoch=1,
):
super().__init__(dataset)

View File

@ -1,60 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import os
import random
from . import BaseWrapperDataset
from fairseq.data import data_utils
class ShardedDataset(BaseWrapperDataset):
"""A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.
Loads a dataset which has been sharded into multiple files. each shard is only loaded for each specific epoch
"""
def __init__(
self,
dictionary,
dataset_impl: str,
path: str,
split: str,
epoch: int,
name: str = None,
combine: bool = False,
seed: int = 0,
):
self._name = name if name is not None else os.path.basename(path)
num_shards = 0
for i in itertools.count():
if not os.path.exists(os.path.join(path, "shard" + str(i))):
break
num_shards += 1
if num_shards > 0 and split == "train":
random.seed(seed ^ epoch)
shard = random.randint(0, num_shards - 1)
split_path = os.path.join(path, "shard" + str(shard), split)
else:
split_path = os.path.join(path, split)
if os.path.isdir(split_path):
split_path = os.path.join(split_path, split)
dataset = data_utils.load_indexed_dataset(
split_path, dictionary, dataset_impl, combine=combine
)
if dataset is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)
super().__init__(dataset)
@property
def name(self):
return self._name

View File

@ -102,7 +102,7 @@ class CrossLingualLMTask(FairseqTask):
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
@ -136,8 +136,9 @@ class CrossLingualLMTask(FairseqTask):
return dataset, sizes
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
@ -165,5 +166,5 @@ class CrossLingualLMTask(FairseqTask):
self.datasets[split] = MultiCorpusSampledDataset(dataset_map)
logger.info('{} {} {} examples'.format(
utils.split_paths(self.args.data)[epoch], split, len(self.datasets[split]))
utils.split_paths(self.args.data)[epoch - 1], split, len(self.datasets[split]))
)

View File

@ -104,16 +104,15 @@ class DenoisingTask(FairseqTask):
args.shuffle_instance = False
return cls(args, dictionary)
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
dataset = data_utils.load_indexed_dataset(

View File

@ -116,7 +116,7 @@ class FairseqTask(object):
num_shards=1,
shard_id=0,
num_workers=0,
epoch=0,
epoch=1,
):
"""
Get an iterator that yields batches of data from the given dataset.
@ -143,7 +143,7 @@ class FairseqTask(object):
loading. 0 means the data will be loaded in the main process
(default: 0).
epoch (int, optional): the epoch to start the iterator from
(default: 0).
(default: 1).
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
given dataset split

View File

@ -148,7 +148,7 @@ class LanguageModelingTask(FairseqTask):
return model
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
@ -157,7 +157,7 @@ class LanguageModelingTask(FairseqTask):
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
dataset = data_utils.load_indexed_dataset(

View File

@ -78,8 +78,9 @@ class LegacyMaskedLMTask(FairseqTask):
return cls(args, dictionary)
def load_dataset(self, split, epoch=0, combine=False):
def load_dataset(self, split, epoch=1, combine=False):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
@ -87,7 +88,7 @@ class LegacyMaskedLMTask(FairseqTask):
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
logger.info("data_path", data_path)
for k in itertools.count():

View File

@ -75,7 +75,7 @@ class MaskedLMTask(FairseqTask):
logger.info('dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
@ -83,7 +83,7 @@ class MaskedLMTask(FairseqTask):
"""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
dataset = data_utils.load_indexed_dataset(

View File

@ -87,15 +87,15 @@ class MultilingualDenoisingTask(DenoisingTask):
smoothed_prob = smoothed_prob / smoothed_prob.sum()
return smoothed_prob
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = self.args.data.split(':')
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
if self.langs is None:

View File

@ -116,7 +116,7 @@ class MultiLingualMaskedLMTask(FairseqTask):
smoothed_prob = smoothed_prob / smoothed_prob.sum()
return smoothed_prob
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
@ -124,7 +124,7 @@ class MultiLingualMaskedLMTask(FairseqTask):
"""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
languages = sorted(
name for name in os.listdir(data_path)
@ -295,7 +295,7 @@ class MultiLingualMaskedLMTask(FairseqTask):
def get_batch_iterator(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1,
seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0,
seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
):
# Recreate epoch iterator every epoch cause the underlying
# datasets are dynamic due to sampling.

View File

@ -187,12 +187,11 @@ class MultilingualTranslationTask(FairseqTask):
new_tgt_bos=new_tgt_bos,
)
def load_dataset(self, split, epoch=0, **kwargs):
def load_dataset(self, split, epoch=1, **kwargs):
"""Load a dataset split."""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
def language_pair_dataset(lang_pair):
src, tgt = lang_pair.split('-')

View File

@ -136,12 +136,11 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
return cls(args, dicts, training)
def load_dataset(self, split, epoch=0, **kwargs):
def load_dataset(self, split, epoch=1, **kwargs):
"""Load a dataset split."""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
def split_exists(split, src, tgt, lang):
if src is not None:

View File

@ -232,7 +232,7 @@ class TranslationTask(FairseqTask):
return cls(args, src_dict, tgt_dict)
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
@ -240,7 +240,7 @@ class TranslationTask(FairseqTask):
"""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang

View File

@ -53,7 +53,7 @@ class TranslationFromPretrainedBARTTask(TranslationTask):
d.add_symbol('[{}]'.format(l))
d.add_symbol('<mask>')
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
@ -61,7 +61,7 @@ class TranslationFromPretrainedBARTTask(TranslationTask):
"""
paths = self.args.data.split(':')
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang

View File

@ -29,7 +29,7 @@ class TranslationLevenshteinTask(TranslationTask):
default='random_delete',
choices=['random_delete', 'random_mask', 'no_noise', 'full_mask'])
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
@ -37,7 +37,7 @@ class TranslationLevenshteinTask(TranslationTask):
"""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
data_path = paths[(epoch - 1) % len(paths)]
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang

View File

@ -477,7 +477,7 @@ class Trainer(object):
self.optimizer.zero_grad()
def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate based on the validation loss."""
"""Adjust the learning rate at the end of the epoch."""
self.lr_scheduler.step(epoch, val_loss)
# prefer updating the LR based on the number of steps
return self.lr_step_update()

View File

@ -56,7 +56,7 @@ def main(args, init_distributed=False):
# Load valid dataset (we load training data below, based on the latest checkpoint)
for valid_sub_split in args.valid_subset.split(','):
task.load_dataset(valid_sub_split, combine=False, epoch=0)
task.load_dataset(valid_sub_split, combine=False, epoch=1)
# Build model and criterion
model = task.build_model(args)
@ -90,7 +90,7 @@ def main(args, init_distributed=False):
while (
lr > args.min_lr
and (
epoch_itr.epoch < max_epoch
epoch_itr.epoch <= max_epoch
# allow resuming training from the final checkpoint
or epoch_itr._next_epoch_itr is not None
)
@ -148,7 +148,7 @@ def train(args, trainer, task, epoch_itr):
# Initialize data iterator
itr = epoch_itr.next_epoch_itr(
fix_batches_to_gpus=args.fix_batches_to_gpus,
shuffle=(epoch_itr.epoch >= args.curriculum),
shuffle=(epoch_itr.epoch > args.curriculum),
)
update_freq = (
args.update_freq[epoch_itr.epoch - 1]
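With 1-based epochs, the train-loop conditions in the hunk above read differently: training continues while `epoch <= max_epoch`, the first `--curriculum` epochs stay unshuffled via a strict `>` comparison, and per-epoch `--update-freq` schedules are indexed with `epoch - 1`. A hedged sketch of those checks (function names are illustrative, not fairseq API):

```python
def should_continue(epoch, max_epoch, lr, min_lr):
    # epochs count 1..max_epoch, so keep going while epoch <= max_epoch
    return lr > min_lr and epoch <= max_epoch


def shuffle_this_epoch(epoch, curriculum):
    # epochs 1..curriculum are served in order; shuffling starts after
    return epoch > curriculum


def update_freq_for(epoch, update_freq):
    # per-epoch schedules remain 0-indexed lists, hence epoch - 1
    return update_freq[min(epoch - 1, len(update_freq) - 1)]


assert should_continue(epoch=1, max_epoch=1, lr=1e-3, min_lr=1e-9)
assert not shuffle_this_epoch(epoch=1, curriculum=1)   # still in curriculum
assert shuffle_this_epoch(epoch=2, curriculum=1)
assert update_freq_for(5, [4, 2]) == 2                 # falls back to the last entry
```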

View File

@ -61,7 +61,7 @@ def main(args, override_args=None):
for subset in args.valid_subset.split(','):
try:
task.load_dataset(subset, combine=False, epoch=0)
task.load_dataset(subset, combine=False, epoch=1)
dataset = task.dataset(subset)
except KeyError:
raise Exception('Cannot find dataset: ' + subset)

View File

@ -67,7 +67,6 @@ class TestLoadCheckpoint(unittest.TestCase):
[p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
@ -110,7 +109,7 @@ class TestLoadCheckpoint(unittest.TestCase):
def test_load_no_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
self.patches['os.path.isfile'].return_value = False