diff --git a/fairseq/options.py b/fairseq/options.py index beb352473..bfbb9765e 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -233,6 +233,8 @@ def add_dataset_args(parser, train=False, gen=False): group.add_argument('--max-sentences-valid', type=int, metavar='N', help='maximum number of sentences in a validation batch' ' (defaults to --max-sentences)') + group.add_argument('--curriculum', default=0, type=int, metavar='N', + help='don\'t shuffle batches for first N epochs') if gen: group.add_argument('--gen-subset', default='test', metavar='SPLIT', help='data subset to generate (train, valid, test)') diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index bb8be0bda..adc0161ce 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -90,9 +90,9 @@ class FairseqTask(object): return self.datasets[split] def get_batch_iterator( - self, dataset, max_tokens=None, max_sentences=None, max_positions=None, - ignore_invalid_inputs=False, required_batch_size_multiple=1, - seed=1, num_shards=1, shard_id=0, num_workers=0, + self, dataset, max_tokens=None, max_sentences=None, max_positions=None, + ignore_invalid_inputs=False, required_batch_size_multiple=1, + seed=1, num_shards=1, shard_id=0, num_workers=0, ): """ Get an iterator that yields batches of data from the given dataset. diff --git a/train.py b/train.py index d678bb7b8..30a441889 100644 --- a/train.py +++ b/train.py @@ -120,13 +120,15 @@ def main(args, init_distributed=False): def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" - # Update parameters every N batches - - # Initialize data iterator - itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus) update_freq = args.update_freq[epoch_itr.epoch - 1] \ if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] + + # Initialize data iterator + itr = epoch_itr.next_epoch_itr( + fix_batches_to_gpus=args.fix_batches_to_gpus, + shuffle=(epoch_itr.epoch >= args.curriculum), + ) itr = iterators.GroupedIterator(itr, update_freq) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, no_progress_bar='simple',