Add --max-update

Myle Ott 2018-03-28 07:56:11 -06:00
parent 70ecd80579
commit 3ea882c57b
2 changed files with 14 additions and 1 deletion
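
In use (hypothetical invocation; the data directory is a placeholder), the new option caps training at a fixed number of optimizer updates: python train.py <data-dir> --max-update 100000 (or the short form --mu 100000) stops training once 100000 updates have run, even if --max-epoch has not been reached.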

@@ -141,6 +141,8 @@ def add_optimization_args(parser):
     group = parser.add_argument_group('Optimization')
     group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
                        help='force stop training at specified epoch')
+    group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
+                       help='force stop training at specified update')
     group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
                        help='clip threshold of gradients')
     group.add_argument('--sentence-avg', action='store_true',
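
Aside, not part of the commit: a minimal runnable sketch of the pattern this option relies on. default=0 acts as a sentinel for "no limit", which train.py later converts with the `or math.inf` idiom.

    # sketch only; mirrors the option definition above
    import argparse
    import math

    parser = argparse.ArgumentParser()
    parser.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
                        help='force stop training at specified update')

    args = parser.parse_args(['--mu', '1000'])
    max_update = args.max_update or math.inf  # 0 (flag unset) becomes inf
    print(max_update)                         # 1000 here; inf if the flag is omitted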

@@ -71,6 +71,7 @@ def main(args):
     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
+    max_update = args.max_update or math.inf
     lr = trainer.get_lr()
     train_meter = StopwatchMeter()
     train_meter.start()
@@ -91,6 +92,9 @@
         epoch += 1
         batch_offset = 0
 
+        if trainer.get_num_updates() >= max_update:
+            break
+
     train_meter.stop()
     print('| done training in {:.1f} seconds'.format(train_meter.sum))
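
Aside, not part of the commit: the update cap is checked in two places. The inner training loop (see the train() hunks below) breaks mid-epoch, and this outer loop re-checks before starting another epoch. A self-contained sketch of that control flow, with a plain counter standing in for the trainer:

    import math

    max_update = 250 or math.inf      # stand-in for args.max_update
    num_updates, epoch = 0, 1
    while epoch <= 10:                # stand-in for the lr/max_epoch condition
        for _ in range(100):          # one epoch of updates
            num_updates += 1          # stand-in for trainer.train_step(...)
            if num_updates >= max_update:
                break                 # stop mid-epoch
        epoch += 1
        if num_updates >= max_update:
            break                     # don't start another epoch
    print(num_updates)                # 250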
@@ -134,6 +138,7 @@ def train(args, trainer, dataset, epoch, batch_offset):
         meter.reset()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
+    max_update = args.max_update or math.inf
     for i, sample in enumerate(itr, start=batch_offset):
         log_output = trainer.train_step(sample)
@@ -153,9 +158,15 @@ def train(args, trainer, dataset, epoch, batch_offset):
         if i == batch_offset:
             # ignore the first mini-batch in words-per-second calculation
             trainer.get_meter('wps').reset()
-        if args.save_interval > 0 and trainer.get_num_updates() % args.save_interval == 0:
+
+        # save mid-epoch checkpoints
+        num_updates = trainer.get_num_updates()
+        if args.save_interval > 0 and num_updates > 0 and num_updates % args.save_interval == 0:
             save_checkpoint(trainer, args, epoch, i + 1)
 
+        if num_updates >= max_update:
+            break
+
     # log end-of-epoch stats
     stats = get_training_stats(trainer)
     for k, meter in extra_meters.items():
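
Aside, not part of the commit: the rewritten condition caches trainer.get_num_updates() in a local and adds a num_updates > 0 guard, since 0 % save_interval == 0 would otherwise trigger a checkpoint before any update has run. A small demonstration of the guard:

    save_interval = 10
    for num_updates in range(0, 31):
        if save_interval > 0 and num_updates > 0 and num_updates % save_interval == 0:
            print('save checkpoint at update', num_updates)  # prints 10, 20, 30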