From 0add50c2e0b5dfaeb0900df08131b0cb87cba273 Mon Sep 17 00:00:00 2001 From: Naman Goyal Date: Mon, 6 May 2019 14:56:13 -0700 Subject: [PATCH] allowing sharded dataset (#696) Summary: Co-authored-by: myleott Changing `data` to be `str` with colon separated list for loading sharded datasets. This change is useful for loading large datasets that cannot fit into, memory. The large dataset can be sharded and then each shard is loaded in one epoch in roudrobin manner. For example, if there are `5` shards of data and `10` epochs then the shards will be iterated upon `[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]`. myleott We need to look into `translation.py` as it currently already expects a list and then concats the datasets. Pull Request resolved: https://github.com/pytorch/fairseq/pull/696 Differential Revision: D15214049 fbshipit-source-id: 03e43a7b69c7aefada2ca668abf1eac1969fe013 --- fairseq/data/iterators.py | 5 +- fairseq/tasks/cross_lingual_lm.py | 23 +++++--- fairseq/tasks/fairseq_task.py | 4 +- fairseq/tasks/language_modeling.py | 16 ++++-- fairseq/tasks/multilingual_translation.py | 20 ++++--- fairseq/tasks/semisupervised_translation.py | 28 +++++---- fairseq/tasks/translation.py | 61 ++++++++++---------- tests/test_train.py | 10 +++- train.py | 64 ++++++++++++++------- 9 files changed, 140 insertions(+), 91 deletions(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 319374481..629f7b162 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -79,11 +79,12 @@ class EpochBatchIterator(object): num_workers (int, optional): how many subprocesses to use for data loading. 0 means the data will be loaded in the main process (default: 0). + epoch (int, optional): The epoch to start the iterator from. """ def __init__( self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0, - num_workers=0, + num_workers=0, epoch=0, ): assert isinstance(dataset, torch.utils.data.Dataset) self.dataset = dataset @@ -94,7 +95,7 @@ class EpochBatchIterator(object): self.shard_id = shard_id self.num_workers = num_workers - self.epoch = 0 + self.epoch = epoch self._cur_epoch_itr = None self._next_epoch_itr = None self._supports_prefetch = getattr(dataset, 'supports_prefetch', False) diff --git a/fairseq/tasks/cross_lingual_lm.py b/fairseq/tasks/cross_lingual_lm.py index dad60932e..799b1c034 100644 --- a/fairseq/tasks/cross_lingual_lm.py +++ b/fairseq/tasks/cross_lingual_lm.py @@ -42,7 +42,8 @@ class CrossLingualLMTask(FairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('data', help='path to data directory') + parser.add_argument('data', help='colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner') parser.add_argument('--tokens-per-sample', default=512, type=int, help='max number of total tokens over all segments' ' per sample') @@ -106,12 +107,16 @@ class CrossLingualLMTask(FairseqTask): return cls(args, dictionary) - def _load_single_lang_dataset(self, split): + def _load_single_lang_dataset(self, split, epoch): loaded_datasets = [] + paths = self.args.data.split(':') + assert len(paths) > 0 + data_path = paths[epoch % len(paths)] + for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') - path = os.path.join(self.args.data, split_k) + path = os.path.join(data_path, split_k) if self.args.raw_text and IndexedRawTextDataset.exists(path): ds = IndexedRawTextDataset(path, self.dictionary) @@ -124,7 +129,7 @@ class CrossLingualLMTask(FairseqTask): if k > 0: break else: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data)) + raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) # Since we append each block with the classification_token, # we need to effectively create blocks of length @@ -136,7 +141,7 @@ class CrossLingualLMTask(FairseqTask): ) ) - print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1]))) + print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1]))) if len(loaded_datasets) == 1: dataset = loaded_datasets[0] @@ -147,7 +152,7 @@ class CrossLingualLMTask(FairseqTask): return dataset, sizes - def load_dataset(self, split, combine=False, **kwargs): + def load_dataset(self, split, epoch=0, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) @@ -162,7 +167,7 @@ class CrossLingualLMTask(FairseqTask): # Datasets are expected to be in "split.lang" format (Eg: train.en) language_split = '{}.{}'.format(split, lang) - block_dataset, sizes = self._load_single_lang_dataset(split=language_split) + block_dataset, sizes = self._load_single_lang_dataset(split=language_split, epoch=epoch) dataset_map[lang] = MaskedLMDataset( dataset=block_dataset, @@ -182,6 +187,6 @@ class CrossLingualLMTask(FairseqTask): dataset_map, default_key=self.default_key ) print('| {} {} {} examples'.format( - self.args.data, split, len(self.datasets[split]) + self.args.data.split(':')[epoch], split, len(self.datasets[split]) ) - ) \ No newline at end of file + ) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 09e409151..607baf3c0 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -92,7 +92,7 @@ class FairseqTask(object): def get_batch_iterator( self, dataset, max_tokens=None, max_sentences=None, max_positions=None, ignore_invalid_inputs=False, required_batch_size_multiple=1, - seed=1, num_shards=1, shard_id=0, num_workers=0, + seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0, ): """ Get an iterator that yields batches of data from the given dataset. @@ -118,6 +118,7 @@ class FairseqTask(object): num_workers (int, optional): how many subprocesses to use for data loading. 0 means the data will be loaded in the main process (default: 0). + epoch (int, optional): The epoch to start the iterator from. Returns: ~fairseq.iterators.EpochBatchIterator: a batched iterator over the @@ -149,6 +150,7 @@ class FairseqTask(object): num_shards=num_shards, shard_id=shard_id, num_workers=num_workers, + epoch=epoch, ) def build_model(self, args): diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 222cb75b4..2fe086ca4 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -104,7 +104,9 @@ class LanguageModelingTask(FairseqTask): dictionary = None output_dictionary = None if args.data: - dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt')) + paths = args.data.split(':') + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt')) print('| dictionary: {} types'.format(len(dictionary))) output_dictionary = dictionary if args.output_dictionary_size >= 0: @@ -136,7 +138,7 @@ class LanguageModelingTask(FairseqTask): return model - def load_dataset(self, split, combine=False, **kwargs): + def load_dataset(self, split, epoch=0, combine=False, **kwargs): """Load a given dataset split. Args: @@ -145,9 +147,13 @@ class LanguageModelingTask(FairseqTask): loaded_datasets = [] + paths = self.args.data.split(':') + assert len(paths) > 0 + data_path = paths[epoch % len(paths)] + for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') - path = os.path.join(self.args.data, split_k) + path = os.path.join(data_path, split_k) if self.args.raw_text and IndexedRawTextDataset.exists(path): ds = IndexedRawTextDataset(path, self.dictionary) @@ -160,7 +166,7 @@ class LanguageModelingTask(FairseqTask): if k > 0: break else: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data)) + raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) loaded_datasets.append( TokenBlockDataset( @@ -170,7 +176,7 @@ class LanguageModelingTask(FairseqTask): ) ) - print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1]))) + print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1]))) if not combine: break diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index cf720f5a7..8f5f60e60 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -135,7 +135,9 @@ class MultilingualTranslationTask(FairseqTask): # load dictionaries dicts = OrderedDict() for lang in sorted_langs: - dicts[lang] = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(lang))) + paths = args.data.split(':') + assert len(paths) > 0 + dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang))) if len(dicts) > 0: assert dicts[lang].pad() == dicts[sorted_langs[0]].pad() assert dicts[lang].eos() == dicts[sorted_langs[0]].eos() @@ -185,11 +187,15 @@ class MultilingualTranslationTask(FairseqTask): new_tgt_bos=new_tgt_bos, ) - def load_dataset(self, split, **kwargs): + def load_dataset(self, split, epoch=0, **kwargs): """Load a dataset split.""" + paths = self.args.data.split(':') + assert len(paths) > 0 + data_path = paths[epoch % len(paths)] + def split_exists(split, src, tgt, lang): - filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang)) + filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) if self.args.raw_text and IndexedRawTextDataset.exists(filename): return True elif not self.args.raw_text and IndexedDataset.exists(filename): @@ -210,17 +216,17 @@ class MultilingualTranslationTask(FairseqTask): for lang_pair in self.args.lang_pairs: src, tgt = lang_pair.split('-') if split_exists(split, src, tgt, src): - prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt)) + prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt)) elif split_exists(split, tgt, src, src): - prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src)) + prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src)) else: continue src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src]) tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt]) - print('| {} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair]))) + print('| {} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair]))) if len(src_datasets) == 0: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data)) + raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) def language_pair_dataset(lang_pair): src, tgt = lang_pair.split('-') diff --git a/fairseq/tasks/semisupervised_translation.py b/fairseq/tasks/semisupervised_translation.py index 34941662d..2dc1cd1e9 100644 --- a/fairseq/tasks/semisupervised_translation.py +++ b/fairseq/tasks/semisupervised_translation.py @@ -132,14 +132,18 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask): dicts, training = MultilingualTranslationTask.prepare(args, **kwargs) return cls(args, dicts, training) - def load_dataset(self, split, **kwargs): + def load_dataset(self, split, epoch=0, **kwargs): """Load a dataset split.""" + paths = self.args.data.split(':') + assert len(paths) > 0 + data_path = paths[epoch % len(paths)] + def split_exists(split, src, tgt, lang): if src is not None: - filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang)) + filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) else: - filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, src, tgt)) + filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt)) if self.args.raw_text and IndexedRawTextDataset.exists(filename): return True elif not self.args.raw_text and IndexedDataset.exists(filename): @@ -162,16 +166,16 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask): for lang_pair in self.args.lang_pairs: src, tgt = lang_pair.split('-') if split_exists(split, src, tgt, src): - prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt)) + prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt)) elif split_exists(split, tgt, src, src): - prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src)) + prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src)) else: continue src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src]) tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt]) - print('| parallel-{} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair]))) + print('| parallel-{} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair]))) if len(src_datasets) == 0: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data)) + raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) # back translation datasets backtranslate_datasets = {} @@ -179,8 +183,8 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask): for lang_pair in self.args.lang_pairs: src, tgt = lang_pair.split('-') if not split_exists(split, tgt, None, tgt): - raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, self.args.data)) - filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt)) + raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, data_path)) + filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt)) dataset = indexed_dataset(filename, self.dicts[tgt]) lang_pair_dataset_tgt = LanguagePairDataset( dataset, @@ -216,7 +220,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask): ).collater, ) print('| backtranslate-{}: {} {} {} examples'.format( - tgt, self.args.data, split, len(backtranslate_datasets[lang_pair]), + tgt, data_path, split, len(backtranslate_datasets[lang_pair]), )) self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair] @@ -227,7 +231,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask): _, tgt = lang_pair.split('-') if not split_exists(split, tgt, None, tgt): continue - filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt)) + filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt)) tgt_dataset1 = indexed_dataset(filename, self.dicts[tgt]) tgt_dataset2 = indexed_dataset(filename, self.dicts[tgt]) noising_dataset = NoisingDataset( @@ -255,7 +259,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask): tgt_lang=tgt, ) print('| denoising-{}: {} {} {} examples'.format( - tgt, self.args.data, split, len(noising_datasets[lang_pair]), + tgt, data_path, split, len(noising_datasets[lang_pair]), )) def language_pair_dataset(lang_pair): diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index cab8b5369..785aa86a5 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -48,7 +48,8 @@ class TranslationTask(FairseqTask): def add_args(parser): """Add task-specific arguments to the parser.""" # fmt: off - parser.add_argument('data', nargs='+', help='path(s) to data directorie(s)') + parser.add_argument('data', help='colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner') parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language') parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', @@ -84,19 +85,17 @@ class TranslationTask(FairseqTask): args.left_pad_source = options.eval_bool(args.left_pad_source) args.left_pad_target = options.eval_bool(args.left_pad_target) - # upgrade old checkpoints - if isinstance(args.data, str): - args.data = [args.data] - + paths = args.data.split(':') + assert len(paths) > 0 # find language pair automatically if args.source_lang is None or args.target_lang is None: - args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0]) + args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0]) if args.source_lang is None or args.target_lang is None: raise Exception('Could not infer language pair, please provide it explicitly') # load dictionaries - src_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang))) - tgt_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang))) + src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang))) + tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang))) assert src_dict.pad() == tgt_dict.pad() assert src_dict.eos() == tgt_dict.eos() assert src_dict.unk() == tgt_dict.unk() @@ -105,12 +104,15 @@ class TranslationTask(FairseqTask): return cls(args, src_dict, tgt_dict) - def load_dataset(self, split, combine=False, **kwargs): + def load_dataset(self, split, epoch=0, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ + paths = self.args.data.split(':') + assert len(paths) > 0 + data_path = paths[epoch % len(paths)] def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) @@ -133,29 +135,28 @@ class TranslationTask(FairseqTask): src_datasets = [] tgt_datasets = [] - for dk, data_path in enumerate(self.args.data): - for k in itertools.count(): - split_k = split + (str(k) if k > 0 else '') + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else '') - # infer langcode - src, tgt = self.args.source_lang, self.args.target_lang - if split_exists(split_k, src, tgt, src, data_path): - prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) - elif split_exists(split_k, tgt, src, src, data_path): - prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) - else: - if k > 0 or dk > 0: - break - else: - raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) - - src_datasets.append(indexed_dataset(prefix + src, self.src_dict)) - tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict)) - - print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1]))) - - if not combine: + # infer langcode + src, tgt = self.args.source_lang, self.args.target_lang + if split_exists(split_k, src, tgt, src, data_path): + prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt)) + elif split_exists(split_k, tgt, src, src, data_path): + prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src)) + else: + if k > 0: break + else: + raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path)) + + src_datasets.append(indexed_dataset(prefix + src, self.src_dict)) + tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict)) + + print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1]))) + + if not combine: + break assert len(src_datasets) == len(tgt_datasets) diff --git a/tests/test_train.py b/tests/test_train.py index d86db8d1c..1c7c43862 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -68,10 +68,13 @@ class TestLoadCheckpoint(unittest.TestCase): [p.start() for p in self.applied_patches] def test_load_partial_checkpoint(self): + with contextlib.redirect_stdout(StringIO()): trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50) - train.load_checkpoint(self.args_mock, trainer, epoch_itr) + with patch('train.reload_train', return_value=epoch_itr): + train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None) + self.assertEqual(epoch_itr.epoch, 2) self.assertEqual(epoch_itr.iterations_in_epoch, 50) @@ -86,7 +89,8 @@ class TestLoadCheckpoint(unittest.TestCase): with contextlib.redirect_stdout(StringIO()): trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150) - train.load_checkpoint(self.args_mock, trainer, epoch_itr) + with patch('train.reload_train', return_value=epoch_itr): + train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None) itr = epoch_itr.next_epoch_itr(shuffle=False) self.assertEqual(epoch_itr.epoch, 3) @@ -98,7 +102,7 @@ class TestLoadCheckpoint(unittest.TestCase): trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0) self.patches['os.path.isfile'].return_value = False - train.load_checkpoint(self.args_mock, trainer, epoch_itr) + train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None) itr = epoch_itr.next_epoch_itr(shuffle=False) self.assertEqual(epoch_itr.epoch, 1) diff --git a/train.py b/train.py index 52d7f7967..653d58a95 100644 --- a/train.py +++ b/train.py @@ -44,7 +44,9 @@ def main(args, init_distributed=False): task = tasks.setup_task(args) # Load dataset splits - load_dataset_splits(args, task) + task.load_dataset(args.train_subset, combine=True, epoch=0) + for valid_sub_split in args.valid_subset.split(','): + task.load_dataset(valid_sub_split, combine=True, epoch=0) # Build model and criterion model = task.build_model(args) @@ -64,15 +66,16 @@ def main(args, init_distributed=False): args.max_sentences, )) + max_positions = utils.resolve_max_positions( + task.max_positions(), + model.max_positions(), + ) # Initialize dataloader epoch_itr = task.get_batch_iterator( dataset=task.dataset(args.train_subset), max_tokens=args.max_tokens, max_sentences=args.max_sentences, - max_positions=utils.resolve_max_positions( - task.max_positions(), - model.max_positions(), - ), + max_positions=max_positions, ignore_invalid_inputs=True, required_batch_size_multiple=args.required_batch_size_multiple, seed=args.seed, @@ -82,7 +85,7 @@ def main(args, init_distributed=False): ) # Load the latest checkpoint if one is available - load_checkpoint(args, trainer, epoch_itr) + load_checkpoint(args, trainer, epoch_itr, max_positions, task) # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf @@ -105,10 +108,34 @@ def main(args, init_distributed=False): # save checkpoint if epoch_itr.epoch % args.save_interval == 0: save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) + + epoch_itr = reload_train(args, epoch_itr, max_positions, task) train_meter.stop() print('| done training in {:.1f} seconds'.format(train_meter.sum)) +def reload_train(args, epoch_itr, max_positions, task): + # nothing needs to be done when the dataset is not sharded. + if len(args.data.split(":")) == 1: + return epoch_itr + print("| Reloading shard of train data at epoch: ", epoch_itr.epoch) + task.load_dataset(args.train_subset, combine=True, epoch=epoch_itr.epoch) + epoch_itr = task.get_batch_iterator( + dataset=task.dataset(args.train_subset), + max_tokens=args.max_tokens, + max_sentences=args.max_sentences, + max_positions=max_positions, + ignore_invalid_inputs=True, + required_batch_size_multiple=args.required_batch_size_multiple, + seed=args.seed, + num_shards=args.distributed_world_size, + shard_id=args.distributed_rank, + num_workers=args.num_workers, + epoch=epoch_itr.epoch, + ) + return epoch_itr + + def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Update parameters every N batches @@ -335,9 +362,8 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): os.remove(old_chk) -def load_checkpoint(args, trainer, epoch_itr): +def load_checkpoint(args, trainer, epoch_itr, max_positions, task): """Load a checkpoint and replay dataloader to match.""" - # Only rank 0 should attempt to create the required dir if args.distributed_rank == 0: os.makedirs(args.save_dir, exist_ok=True) @@ -351,7 +377,14 @@ def load_checkpoint(args, trainer, epoch_itr): eval(args.optimizer_overrides)) if extra_state is not None: # replay train iterator to match checkpoint - epoch_itr.load_state_dict(extra_state['train_iterator']) + epoch_itr_state = extra_state['train_iterator'] + + # If the loaded checkpoint is not at epoch 0, reload train dataset, + # as it could be potentially sharded. + if epoch_itr_state['epoch'] != 0: + epoch_itr = reload_train(args, epoch_itr, max_positions, task) + + epoch_itr.load_state_dict(epoch_itr_state) print('| loaded checkpoint {} (epoch {} @ {} updates)'.format( checkpoint_path, epoch_itr.epoch, trainer.get_num_updates())) @@ -366,19 +399,6 @@ def load_checkpoint(args, trainer, epoch_itr): return False -def load_dataset_splits(args, task): - task.load_dataset(args.train_subset, combine=True) - for split in args.valid_subset.split(','): - for k in itertools.count(): - split_k = split + (str(k) if k > 0 else '') - try: - task.load_dataset(split_k, combine=False) - except FileNotFoundError as e: - if k > 0: - break - raise e - - def distributed_main(i, args, start_rank=0): args.device_id = i if args.distributed_rank is None: # torch.multiprocessing.spawn