diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py
index 319374481..629f7b162 100644
--- a/fairseq/data/iterators.py
+++ b/fairseq/data/iterators.py
@@ -79,11 +79,12 @@ class EpochBatchIterator(object):
         num_workers (int, optional): how many subprocesses to use for data
             loading. 0 means the data will be loaded in the main process
            (default: 0).
+        epoch (int, optional): The epoch to start the iterator from.
     """

     def __init__(
         self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
-        num_workers=0,
+        num_workers=0, epoch=0,
     ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
@@ -94,7 +95,7 @@ class EpochBatchIterator(object):
         self.shard_id = shard_id
         self.num_workers = num_workers

-        self.epoch = 0
+        self.epoch = epoch
         self._cur_epoch_itr = None
         self._next_epoch_itr = None
         self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)
diff --git a/fairseq/tasks/cross_lingual_lm.py b/fairseq/tasks/cross_lingual_lm.py
index dad60932e..799b1c034 100644
--- a/fairseq/tasks/cross_lingual_lm.py
+++ b/fairseq/tasks/cross_lingual_lm.py
@@ -42,7 +42,8 @@ class CrossLingualLMTask(FairseqTask):
     @staticmethod
     def add_args(parser):
         """Add task-specific arguments to the parser."""
-        parser.add_argument('data', help='path to data directory')
+        parser.add_argument('data', help='colon-separated list of data directory paths, '
+                                         'iterated over in round-robin order across epochs')
         parser.add_argument('--tokens-per-sample', default=512, type=int,
                             help='max number of total tokens over all segments'
                                  ' per sample')
@@ -106,12 +107,16 @@ class CrossLingualLMTask(FairseqTask):

         return cls(args, dictionary)

-    def _load_single_lang_dataset(self, split):
+    def _load_single_lang_dataset(self, split, epoch):
         loaded_datasets = []

+        paths = self.args.data.split(':')
+        assert len(paths) > 0
+        data_path = paths[epoch % len(paths)]
+
         for k in itertools.count():
             split_k = split + (str(k) if k > 0 else '')
-            path = os.path.join(self.args.data, split_k)
+            path = os.path.join(data_path, split_k)

             if self.args.raw_text and IndexedRawTextDataset.exists(path):
                 ds = IndexedRawTextDataset(path, self.dictionary)
@@ -124,7 +129,7 @@
                 if k > 0:
                     break
                 else:
-                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
+                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

             # Since we append each block with the classification_token,
             # we need to effectively create blocks of length
@@ -136,7 +141,7 @@
                 )
             )

-            print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
+            print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))

         if len(loaded_datasets) == 1:
             dataset = loaded_datasets[0]
@@ -147,7 +152,7 @@
         return dataset, sizes

-    def load_dataset(self, split, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
         """Load a given dataset split.

         Args:
             split (str): name of the split (e.g., train, valid, test)
@@ -162,7 +167,7 @@
             # Datasets are expected to be in "split.lang" format (Eg: train.en)
             language_split = '{}.{}'.format(split, lang)

-            block_dataset, sizes = self._load_single_lang_dataset(split=language_split)
+            block_dataset, sizes = self._load_single_lang_dataset(split=language_split, epoch=epoch)

             dataset_map[lang] = MaskedLMDataset(
                 dataset=block_dataset,
@@ -182,6 +187,6 @@
             dataset_map, default_key=self.default_key
         )
         print('| {} {} {} examples'.format(
-            self.args.data, split, len(self.datasets[split])
+            self.args.data.split(':')[epoch % len(self.args.data.split(':'))], split, len(self.datasets[split])
             )
-        )
\ No newline at end of file
+        )
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 09e409151..607baf3c0 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -92,7 +92,7 @@ class FairseqTask(object):
     def get_batch_iterator(
         self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
         ignore_invalid_inputs=False, required_batch_size_multiple=1,
-        seed=1, num_shards=1, shard_id=0, num_workers=0,
+        seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0,
     ):
         """
         Get an iterator that yields batches of data from the given dataset.
@@ -118,6 +118,7 @@ class FairseqTask(object):
             num_workers (int, optional): how many subprocesses to use for data
                 loading. 0 means the data will be loaded in the main process
                 (default: 0).
+            epoch (int, optional): The epoch to start the iterator from.

         Returns:
             ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
@@ -149,6 +150,7 @@ class FairseqTask(object):
             num_shards=num_shards,
             shard_id=shard_id,
             num_workers=num_workers,
+            epoch=epoch,
         )

     def build_model(self, args):
diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py
index 222cb75b4..2fe086ca4 100644
--- a/fairseq/tasks/language_modeling.py
+++ b/fairseq/tasks/language_modeling.py
@@ -104,7 +104,9 @@ class LanguageModelingTask(FairseqTask):
         dictionary = None
         output_dictionary = None
         if args.data:
-            dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
+            paths = args.data.split(':')
+            assert len(paths) > 0
+            dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
             print('| dictionary: {} types'.format(len(dictionary)))
             output_dictionary = dictionary
             if args.output_dictionary_size >= 0:
@@ -136,7 +138,7 @@

         return model

-    def load_dataset(self, split, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
         """Load a given dataset split.

         Args:
@@ -145,9 +147,13 @@
         loaded_datasets = []

+        paths = self.args.data.split(':')
+        assert len(paths) > 0
+        data_path = paths[epoch % len(paths)]
+
         for k in itertools.count():
             split_k = split + (str(k) if k > 0 else '')
-            path = os.path.join(self.args.data, split_k)
+            path = os.path.join(data_path, split_k)

             if self.args.raw_text and IndexedRawTextDataset.exists(path):
                 ds = IndexedRawTextDataset(path, self.dictionary)
@@ -160,7 +166,7 @@
                 if k > 0:
                     break
                 else:
-                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
+                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

             loaded_datasets.append(
                 TokenBlockDataset(
@@ -170,7 +176,7 @@
                 )
             )

-            print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
+            print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))

             if not combine:
                 break
diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py
index cf720f5a7..8f5f60e60 100644
--- a/fairseq/tasks/multilingual_translation.py
+++ b/fairseq/tasks/multilingual_translation.py
@@ -135,7 +135,9 @@ class MultilingualTranslationTask(FairseqTask):
         # load dictionaries
         dicts = OrderedDict()
         for lang in sorted_langs:
-            dicts[lang] = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(lang)))
+            paths = args.data.split(':')
+            assert len(paths) > 0
+            dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
             if len(dicts) > 0:
                 assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                 assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
@@ -185,11 +187,15 @@ class MultilingualTranslationTask(FairseqTask):
                 new_tgt_bos=new_tgt_bos,
             )

-    def load_dataset(self, split, **kwargs):
+    def load_dataset(self, split, epoch=0, **kwargs):
         """Load a dataset split."""

+        paths = self.args.data.split(':')
+        assert len(paths) > 0
+        data_path = paths[epoch % len(paths)]
+
         def split_exists(split, src, tgt, lang):
-            filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
+            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
             if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                 return True
             elif not self.args.raw_text and IndexedDataset.exists(filename):
@@ -210,17 +216,17 @@
         for lang_pair in self.args.lang_pairs:
             src, tgt = lang_pair.split('-')
             if split_exists(split, src, tgt, src):
-                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
             elif split_exists(split, tgt, src, src):
-                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
             else:
                 continue
             src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
             tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
-            print('| {} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair])))
+            print('| {} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))

         if len(src_datasets) == 0:
-            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
+            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

         def language_pair_dataset(lang_pair):
             src, tgt = lang_pair.split('-')
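
Note: the task changes above all follow the same pattern: dictionaries are still read from the first directory in the colon-separated list, while the data for a given epoch comes from paths[epoch % len(paths)]. A minimal standalone sketch of that selection rule (the shard paths below are hypothetical placeholders, not values from the patch):

    # Round-robin shard selection as used by the task hunks above; paths are hypothetical.
    data = 'data-bin/shard0:data-bin/shard1:data-bin/shard2'
    paths = data.split(':')
    assert len(paths) > 0

    dict_path = paths[0]  # dictionaries are always loaded from the first shard
    for epoch in range(5):
        data_path = paths[epoch % len(paths)]
        print(epoch, data_path)  # 0..4 -> shard0, shard1, shard2, shard0, shard1

Since only paths[0] is consulted for dict.txt / dict.<lang>.txt, this implicitly assumes every shard was binarized with the same vocabulary.
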
diff --git a/fairseq/tasks/semisupervised_translation.py b/fairseq/tasks/semisupervised_translation.py
index 34941662d..2dc1cd1e9 100644
--- a/fairseq/tasks/semisupervised_translation.py
+++ b/fairseq/tasks/semisupervised_translation.py
@@ -132,14 +132,18 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
         dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
         return cls(args, dicts, training)

-    def load_dataset(self, split, **kwargs):
+    def load_dataset(self, split, epoch=0, **kwargs):
         """Load a dataset split."""

+        paths = self.args.data.split(':')
+        assert len(paths) > 0
+        data_path = paths[epoch % len(paths)]
+
         def split_exists(split, src, tgt, lang):
             if src is not None:
-                filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
+                filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
             else:
-                filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, src, tgt))
+                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt))
             if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                 return True
             elif not self.args.raw_text and IndexedDataset.exists(filename):
@@ -162,16 +166,16 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
         for lang_pair in self.args.lang_pairs:
             src, tgt = lang_pair.split('-')
             if split_exists(split, src, tgt, src):
-                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
             elif split_exists(split, tgt, src, src):
-                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
             else:
                 continue
             src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
             tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
-            print('| parallel-{} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair])))
+            print('| parallel-{} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))

         if len(src_datasets) == 0:
-            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
+            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

         # back translation datasets
         backtranslate_datasets = {}
@@ -179,8 +183,8 @@
             for lang_pair in self.args.lang_pairs:
                 src, tgt = lang_pair.split('-')
                 if not split_exists(split, tgt, None, tgt):
-                    raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, self.args.data))
-                filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
+                    raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, data_path))
+                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
                 dataset = indexed_dataset(filename, self.dicts[tgt])
                 lang_pair_dataset_tgt = LanguagePairDataset(
                     dataset,
@@ -216,7 +220,7 @@
                     ).collater,
                 )
                 print('| backtranslate-{}: {} {} {} examples'.format(
-                    tgt, self.args.data, split, len(backtranslate_datasets[lang_pair]),
+                    tgt, data_path, split, len(backtranslate_datasets[lang_pair]),
                 ))
                 self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]

@@ -227,7 +231,7 @@
                 _, tgt = lang_pair.split('-')
                 if not split_exists(split, tgt, None, tgt):
                     continue
-                filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
+                filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
                 tgt_dataset1 = indexed_dataset(filename, self.dicts[tgt])
                 tgt_dataset2 = indexed_dataset(filename, self.dicts[tgt])
                 noising_dataset = NoisingDataset(
@@ -255,7 +259,7 @@
                     tgt_lang=tgt,
                 )
                 print('| denoising-{}: {} {} {} examples'.format(
-                    tgt, self.args.data, split, len(noising_datasets[lang_pair]),
+                    tgt, data_path, split, len(noising_datasets[lang_pair]),
                 ))

         def language_pair_dataset(lang_pair):
diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py
index cab8b5369..785aa86a5 100644
--- a/fairseq/tasks/translation.py
+++ b/fairseq/tasks/translation.py
@@ -48,7 +48,8 @@ class TranslationTask(FairseqTask):
     def add_args(parser):
         """Add task-specific arguments to the parser."""
         # fmt: off
-        parser.add_argument('data', nargs='+', help='path(s) to data directorie(s)')
+        parser.add_argument('data', help='colon-separated list of data directory paths, '
+                                         'iterated over in round-robin order across epochs')
         parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                             help='source language')
         parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
@@ -84,19 +85,17 @@ class TranslationTask(FairseqTask):
         args.left_pad_source = options.eval_bool(args.left_pad_source)
         args.left_pad_target = options.eval_bool(args.left_pad_target)

-        # upgrade old checkpoints
-        if isinstance(args.data, str):
-            args.data = [args.data]
-
+        paths = args.data.split(':')
+        assert len(paths) > 0
         # find language pair automatically
         if args.source_lang is None or args.target_lang is None:
-            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
+            args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
         if args.source_lang is None or args.target_lang is None:
             raise Exception('Could not infer language pair, please provide it explicitly')

         # load dictionaries
-        src_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
-        tgt_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
+        src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
+        tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
         assert src_dict.pad() == tgt_dict.pad()
         assert src_dict.eos() == tgt_dict.eos()
         assert src_dict.unk() == tgt_dict.unk()
@@ -105,12 +104,15 @@ class TranslationTask(FairseqTask):

         return cls(args, src_dict, tgt_dict)

-    def load_dataset(self, split, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
         """Load a given dataset split.

         Args:
             split (str): name of the split (e.g., train, valid, test)
         """
+        paths = self.args.data.split(':')
+        assert len(paths) > 0
+        data_path = paths[epoch % len(paths)]

         def split_exists(split, src, tgt, lang, data_path):
             filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
@@ -133,29 +135,28 @@
         src_datasets = []
         tgt_datasets = []

-        for dk, data_path in enumerate(self.args.data):
-            for k in itertools.count():
-                split_k = split + (str(k) if k > 0 else '')
+        for k in itertools.count():
+            split_k = split + (str(k) if k > 0 else '')

-                # infer langcode
-                src, tgt = self.args.source_lang, self.args.target_lang
-                if split_exists(split_k, src, tgt, src, data_path):
-                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
-                elif split_exists(split_k, tgt, src, src, data_path):
-                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
-                else:
-                    if k > 0 or dk > 0:
-                        break
-                    else:
-                        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
-
-                src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
-                tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
-
-                print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))
-
-                if not combine:
+            # infer langcode
+            src, tgt = self.args.source_lang, self.args.target_lang
+            if split_exists(split_k, src, tgt, src, data_path):
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
+            elif split_exists(split_k, tgt, src, src, data_path):
+                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
+            else:
+                if k > 0:
                     break
+                else:
+                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
+
+            src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
+            tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
+
+            print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))
+
+            if not combine:
+                break

         assert len(src_datasets) == len(tgt_datasets)
diff --git a/tests/test_train.py b/tests/test_train.py
index d86db8d1c..1c7c43862 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -68,10 +68,13 @@ class TestLoadCheckpoint(unittest.TestCase):
         [p.start() for p in self.applied_patches]

     def test_load_partial_checkpoint(self):
+
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
-            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
+            with patch('train.reload_train', return_value=epoch_itr):
+                train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
+
             self.assertEqual(epoch_itr.epoch, 2)
             self.assertEqual(epoch_itr.iterations_in_epoch, 50)
@@ -86,7 +89,8 @@
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
-            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
+            with patch('train.reload_train', return_value=epoch_itr):
+                train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)

             itr = epoch_itr.next_epoch_itr(shuffle=False)
             self.assertEqual(epoch_itr.epoch, 3)
@@ -98,7 +102,7 @@
             trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
             self.patches['os.path.isfile'].return_value = False
-            train.load_checkpoint(self.args_mock, trainer, epoch_itr)
+            train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
             itr = epoch_itr.next_epoch_itr(shuffle=False)

             self.assertEqual(epoch_itr.epoch, 1)
diff --git a/train.py b/train.py
index 52d7f7967..653d58a95 100644
--- a/train.py
+++ b/train.py
@@ -44,7 +44,9 @@ def main(args, init_distributed=False):
     task = tasks.setup_task(args)

     # Load dataset splits
-    load_dataset_splits(args, task)
+    task.load_dataset(args.train_subset, combine=True, epoch=0)
+    for valid_sub_split in args.valid_subset.split(','):
+        task.load_dataset(valid_sub_split, combine=True, epoch=0)

     # Build model and criterion
     model = task.build_model(args)
@@ -64,15 +66,16 @@ def main(args, init_distributed=False):
         args.max_sentences,
     ))

+    max_positions = utils.resolve_max_positions(
+        task.max_positions(),
+        model.max_positions(),
+    )
     # Initialize dataloader
     epoch_itr = task.get_batch_iterator(
         dataset=task.dataset(args.train_subset),
         max_tokens=args.max_tokens,
         max_sentences=args.max_sentences,
-        max_positions=utils.resolve_max_positions(
-            task.max_positions(),
-            model.max_positions(),
-        ),
+        max_positions=max_positions,
         ignore_invalid_inputs=True,
         required_batch_size_multiple=args.required_batch_size_multiple,
         seed=args.seed,
@@ -82,7 +85,7 @@ def main(args, init_distributed=False):
     )

     # Load the latest checkpoint if one is available
-    load_checkpoint(args, trainer, epoch_itr)
+    load_checkpoint(args, trainer, epoch_itr, max_positions, task)

     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
@@ -105,10 +108,34 @@ def main(args, init_distributed=False):
         # save checkpoint
         if epoch_itr.epoch % args.save_interval == 0:
             save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
+
+        epoch_itr = reload_train(args, epoch_itr, max_positions, task)
     train_meter.stop()
     print('| done training in {:.1f} seconds'.format(train_meter.sum))


+def reload_train(args, epoch_itr, max_positions, task):
+    # nothing needs to be done when the dataset is not sharded.
+    if len(args.data.split(":")) == 1:
+        return epoch_itr
+    print("| Reloading shard of train data at epoch: ", epoch_itr.epoch)
+    task.load_dataset(args.train_subset, combine=True, epoch=epoch_itr.epoch)
+    epoch_itr = task.get_batch_iterator(
+        dataset=task.dataset(args.train_subset),
+        max_tokens=args.max_tokens,
+        max_sentences=args.max_sentences,
+        max_positions=max_positions,
+        ignore_invalid_inputs=True,
+        required_batch_size_multiple=args.required_batch_size_multiple,
+        seed=args.seed,
+        num_shards=args.distributed_world_size,
+        shard_id=args.distributed_rank,
+        num_workers=args.num_workers,
+        epoch=epoch_itr.epoch,
+    )
+    return epoch_itr
+
+
 def train(args, trainer, task, epoch_itr):
     """Train the model for one epoch."""
     # Update parameters every N batches
@@ -335,9 +362,8 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
             os.remove(old_chk)


-def load_checkpoint(args, trainer, epoch_itr):
+def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
     """Load a checkpoint and replay dataloader to match."""
-    # Only rank 0 should attempt to create the required dir
     if args.distributed_rank == 0:
         os.makedirs(args.save_dir, exist_ok=True)
@@ -351,7 +377,14 @@ def load_checkpoint(args, trainer, epoch_itr):
                                             eval(args.optimizer_overrides))
         if extra_state is not None:
             # replay train iterator to match checkpoint
-            epoch_itr.load_state_dict(extra_state['train_iterator'])
+            epoch_itr_state = extra_state['train_iterator']
+
+            # If the loaded checkpoint is not at epoch 0, reload train dataset,
+            # as it could be potentially sharded.
+            if epoch_itr_state['epoch'] != 0:
+                epoch_itr = reload_train(args, epoch_itr, max_positions, task)
+
+            epoch_itr.load_state_dict(epoch_itr_state)

             print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                 checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
@@ -366,19 +399,6 @@ def load_checkpoint(args, trainer, epoch_itr):
     return False


-def load_dataset_splits(args, task):
-    task.load_dataset(args.train_subset, combine=True)
-    for split in args.valid_subset.split(','):
-        for k in itertools.count():
-            split_k = split + (str(k) if k > 0 else '')
-            try:
-                task.load_dataset(split_k, combine=False)
-            except FileNotFoundError as e:
-                if k > 0:
-                    break
-                raise e
-
-
 def distributed_main(i, args, start_rank=0):
     args.device_id = i
     if args.distributed_rank is None:  # torch.multiprocessing.spawn
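
A rough sketch of the behaviour the train.py changes aim for, simulated without fairseq (the shard paths and the checkpoint epoch below are hypothetical, not values from the patch):

    # reload_train is a no-op for a single data path; with several paths each new
    # epoch rotates to the next shard, and resuming from a checkpoint first reloads
    # the shard matching the saved epoch before the iterator state is replayed.
    def shard_for_epoch(data, epoch):
        paths = data.split(':')
        assert len(paths) > 0
        return paths[epoch % len(paths)]

    data = '/data/shard0:/data/shard1:/data/shard2'
    for epoch in range(1, 5):
        print('epoch', epoch, '-> train on', shard_for_epoch(data, epoch))

    checkpoint_epoch = 4  # hypothetical extra_state['train_iterator']['epoch']
    print('resume ->', shard_for_epoch(data, checkpoint_epoch))  # /data/shard1

With a single directory the rotation degenerates to the previous behaviour, which is why reload_train returns the existing iterator unchanged in that case.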