optional limit on total training time (#2333)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2333

This change adds a new option (`--stop-time-hours`) which, if specified, limits the total training time to that number of hours. In order to stop training from within the inner training loop (after the first update that exceeds the time limit), the start time is stored on the trainer.
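For illustration, the stopping condition is just a per-update wall-clock comparison. A minimal standalone sketch of the idea (not the fairseq code; `train_one_update` and `train_until_time_limit` are hypothetical stand-ins):

    import time

    def train_one_update():
        time.sleep(0.1)  # stand-in for a real forward/backward/optimizer step

    def train_until_time_limit(max_hours):
        start = time.time()
        while True:
            train_one_update()
            elapsed_hours = (time.time() - start) / (60 * 60)
            # a value of 0 disables the limit; otherwise stop after the
            # first update that crosses it (checked per update, not per epoch)
            if max_hours > 0 and elapsed_hours > max_hours:
                break

    train_until_time_limit(max_hours=1e-5)  # ~36 ms, just to exercise the loop

Because the check runs inside the update loop, at most one extra update happens after the limit is reached.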

In addition, in order to persist training time when restoring from checkpoints (important because training runs are sometimes killed due to resource constraints), the training time already completed is stored as extra state in checkpoints (this change remains backward compatible with existing checkpoints that lack the new state).
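For illustration, the persistence pattern looks like the following standalone sketch (`save_ckpt` and `load_ckpt` are illustrative names, not the fairseq checkpoint API):

    import time
    import torch

    def save_ckpt(path, model_state, elapsed_seconds):
        # completed training time rides along as extra state
        torch.save({"model": model_state,
                    "extra_state": {"previous_training_time": elapsed_seconds}}, path)

    def load_ckpt(path):
        state = torch.load(path)
        extra = state.get("extra_state", {})
        # older checkpoints lack the key; defaulting to 0 keeps them loadable
        previous_training_time = extra.get("previous_training_time", 0)
        start_time = time.time()  # restart the clock for the current process
        return state["model"], previous_training_time, start_time

On restore, cumulative time is then computed as (now - start_time) + previous_training_time, so the limit applies to total training time across restarts.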

Reviewed By: myleott

Differential Revision: D22573166

fbshipit-source-id: 01c59274a1c196acc8a3a0243814167e1d368b1a
Authored by James Cross (2020-07-16 17:43:42 -07:00), committed by Facebook GitHub Bot
parent 75d354c92b
commit 3655cf266e
3 changed files with 33 additions and 2 deletions

View File

@@ -455,6 +455,8 @@ def add_optimization_args(parser):
                        help='force stop training at specified epoch')
     group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
                        help='force stop training at specified update')
+    group.add_argument('--stop-time-hours', default=0, type=float, metavar='N',
+                       help='force stop training after specified cumulative time (if >0)')
     group.add_argument('--clip-norm', default=0.0, type=float, metavar='NORM',
                        help='clip threshold of gradients')
     group.add_argument('--sentence-avg', action='store_true',

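Since the default is 0 and the stop check requires a value > 0, omitting the flag leaves training time unbounded. A quick standalone argparse check of that behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--stop-time-hours', default=0, type=float, metavar='N',
                        help='force stop training after specified cumulative time (if >0)')

    args = parser.parse_args([])  # flag omitted
    assert args.stop_time_hours == 0  # limit disabled

    args = parser.parse_args(['--stop-time-hours', '72'])
    assert args.stop_time_hours == 72.0  # stop after 72 cumulative hours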
View File

@@ -11,6 +11,7 @@ import contextlib
 from itertools import chain
 import logging
 import sys
+import time
 from typing import Any, Dict, List

 import torch
@@ -110,6 +111,10 @@ class Trainer(object):
         metrics.log_start_time("wall", priority=790, round=0)

+        self._start_time = time.time()
+        self._previous_training_time = 0
+        self._cumulative_training_time = None
+
     def reinitialize(self):
         """Reinitialize the Trainer, typically after model params change."""
         self._lr_scheduler = None
@@ -218,6 +223,7 @@
         """Save all training state in a checkpoint file."""
         if self.is_data_parallel_master:  # only save one checkpoint
             extra_state["metrics"] = metrics.state_dict()
+            extra_state["previous_training_time"] = self.cumulative_training_time()
             checkpoint_utils.save_state(
                 filename,
                 self.args,
@@ -291,6 +297,10 @@
                 )
             )

+            if "previous_training_time" in extra_state:
+                self._previous_training_time = extra_state["previous_training_time"]
+                self._start_time = time.time()
+
             self.lr_step(epoch)

             if "metrics" in extra_state and not reset_meters:
@@ -468,9 +478,11 @@
         # gather logging outputs from all replicas
         if self._sync_stats():
-            logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
-                logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
+            train_time = self._local_cumulative_training_time()
+            logging_outputs, (sample_size, ooms, total_train_time) = self._aggregate_logging_outputs(
+                logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch,
             )
+            self._cumulative_training_time = total_train_time / self.data_parallel_world_size

         overflow = False
         try:
@@ -716,6 +728,17 @@
     def clip_grad_norm(self, clip_norm):
         return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None)

+    def cumulative_training_time(self):
+        if self._cumulative_training_time is None:
+            # single GPU
+            return self._local_cumulative_training_time()
+        else:
+            return self._cumulative_training_time
+
+    def _local_cumulative_training_time(self):
+        """Aggregate training time in seconds."""
+        return time.time() - self._start_time + self._previous_training_time
+
     def _prepare_sample(self, sample):
         if sample == "DUMMY":
             raise Exception(

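The effect of storing `_previous_training_time` is that wall-clock time accumulates across restarts; in the distributed case, the per-worker times summed by `_aggregate_logging_outputs` are divided by the world size, i.e. averaged, so the result stays in seconds. A toy standalone illustration of the single-process arithmetic (class name and numbers are made up):

    import time

    class TimeTracker:
        # miniature version of the trainer's bookkeeping
        def __init__(self, previous_training_time=0):
            self._start_time = time.time()
            self._previous_training_time = previous_training_time

        def cumulative(self):
            return time.time() - self._start_time + self._previous_training_time

    # suppose a first run trained for 3600s before being killed, and a second
    # run restored that figure from the checkpoint's extra state
    tracker = TimeTracker(previous_training_time=3600)
    print(tracker.cumulative())  # ~3600 immediately after restoring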
View File

@@ -229,6 +229,12 @@ def train(args, trainer, task, epoch_itr):
         valid_losses, should_stop = validate_and_save(
             args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
         )

+        if args.stop_time_hours > 0:
+            elapsed_hours = trainer.cumulative_training_time() / (60 * 60)
+            if elapsed_hours > args.stop_time_hours:
+                should_stop = True
+
         if should_stop:
             break
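
Since `cumulative_training_time()` returns seconds, the division by 60 * 60 converts to hours before comparing. A worked standalone example of the check above (values are made up):

    stop_time_hours = 36
    cumulative_seconds = 172_800  # 48 hours of training so far
    elapsed_hours = cumulative_seconds / (60 * 60)  # 48.0
    should_stop = stop_time_hours > 0 and elapsed_hours > stop_time_hours
    print(should_stop)  # True: the 36-hour limit has been exceeded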