Clean up sharded train iterator

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/586

Differential Revision: D15372949

Pulled By: myleott

fbshipit-source-id: c1cf1c645e8d55fc8568f23a47c45677ac9ab1da
Authored by Myle Ott on 2019-05-16 21:00:17 -07:00; committed by Facebook Github Bot
parent fca32e0565
commit 3bfbb49ba5
6 changed files with 76 additions and 88 deletions
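
In short, after this change the caller no longer builds the train iterator itself; checkpoint_utils.load_checkpoint returns it together with any extra checkpoint state. A minimal sketch of the new call pattern, using only names that appear in the diff below:

    # train.py side: load the checkpoint (if any) and get the epoch iterator back
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # inside load_checkpoint the iterator is rebuilt via the new Trainer helper,
    # either from the checkpointed 'train_iterator' state or fresh at epoch 0:
    #     epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'])
    #     epoch_itr.load_state_dict(itr_state)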

fairseq/checkpoint_utils.py

@@ -87,9 +87,9 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
                 os.remove(old_chk)


-def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
-    """Load a checkpoint and replay dataloader to match."""
-    # Only rank 0 should attempt to create the required dir
+def load_checkpoint(args, trainer):
+    """Load a checkpoint and restore the training iterator."""
+    # only one worker should attempt to create the required dir
     if args.distributed_rank == 0:
         os.makedirs(args.save_dir, exist_ok=True)
@@ -97,32 +97,26 @@ def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
         checkpoint_path = args.restore_file
     else:
         checkpoint_path = os.path.join(args.save_dir, args.restore_file)
-    if os.path.isfile(checkpoint_path):
-        extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
-                                              eval(args.optimizer_overrides))
-        if extra_state is not None:
-            # replay train iterator to match checkpoint
-            epoch_itr_state = extra_state['train_iterator']
-            # If the loaded checkpoint is not at epoch 0, reload train dataset,
-            # as it could be potentially sharded.
-            if epoch_itr_state['epoch'] != 0:
-                epoch_itr = reload_train(args, epoch_itr, max_positions, task)
+    extra_state = trainer.load_checkpoint(
+        checkpoint_path,
+        args.reset_optimizer,
+        args.reset_lr_scheduler,
+        eval(args.optimizer_overrides),
+    )
-            epoch_itr.load_state_dict(epoch_itr_state)
+    if extra_state is not None and 'best' in extra_state and not args.reset_optimizer:
+        save_checkpoint.best = extra_state['best']
-            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
-                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
-            trainer.lr_step(epoch_itr.epoch)
-            trainer.lr_step_update(trainer.get_num_updates())
-            if 'best' in extra_state and not args.reset_optimizer:
-                save_checkpoint.best = extra_state['best']
-            return True
+    if extra_state is not None:
+        # restore iterator from checkpoint
+        itr_state = extra_state['train_iterator']
+        epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'])
+        epoch_itr.load_state_dict(itr_state)
     else:
-        print('| no existing checkpoint found {}'.format(checkpoint_path))
-    return False
+        epoch_itr = trainer.get_train_iterator(epoch=0)
+
+    return extra_state, epoch_itr


 def load_checkpoint_to_cpu(path):
@@ -165,28 +159,6 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None):
     return ensemble, args


-def reload_train(args, epoch_itr, max_positions, task):
-    # nothing needs to be done when the dataset is not sharded.
-    if "data" not in args or ("data" in args and len(args.data.split(":")) == 1):
-        return epoch_itr
-    print("| Reloading shard of train data at epoch: ", epoch_itr.epoch)
-    task.load_dataset(args.train_subset, combine=True, epoch=epoch_itr.epoch)
-    epoch_itr = task.get_batch_iterator(
-        dataset=task.dataset(args.train_subset),
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences,
-        max_positions=max_positions,
-        ignore_invalid_inputs=True,
-        required_batch_size_multiple=args.required_batch_size_multiple,
-        seed=args.seed,
-        num_shards=args.distributed_world_size,
-        shard_id=args.distributed_rank,
-        num_workers=args.num_workers,
-        epoch=epoch_itr.epoch,
-    )
-    return epoch_itr


 def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
     """Retrieves all checkpoints found in `path` directory.

fairseq/data/iterators.py

@@ -76,7 +76,8 @@ class EpochBatchIterator(object):
         num_workers (int, optional): how many subprocesses to use for data
             loading. 0 means the data will be loaded in the main process
             (default: 0).
-        epoch (int, optional): The epoch to start the iterator from.
+        epoch (int, optional): the epoch to start the iterator from
+            (default: 0).
     """

     def __init__(

fairseq/tasks/fairseq_task.py

@@ -118,7 +118,8 @@ class FairseqTask(object):
             num_workers (int, optional): how many subprocesses to use for data
                 loading. 0 means the data will be loaded in the main process
                 (default: 0).
-            epoch (int, optional): The epoch to start the iterator from.
+            epoch (int, optional): the epoch to start the iterator from
+                (default: 0).

         Returns:
             ~fairseq.iterators.EpochBatchIterator: a batched iterator over the

fairseq/trainer.py

@@ -124,8 +124,9 @@ class Trainer(object):
         if distributed_utils.is_master(self.args):  # only save one checkpoint
             extra_state['train_meters'] = self.meters
             checkpoint_utils.save_state(
-                filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
-                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
+                filename, self.args, self.get_model().state_dict(), self.criterion,
+                self.optimizer, self.lr_scheduler, self._num_updates,
+                self._optim_history, extra_state,
             )

     def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
@@ -165,17 +166,48 @@ class Trainer(object):
             self._num_updates = last_optim['num_updates']

-        if extra_state is not None and 'train_meters' in extra_state:
-            self.meters.update(extra_state['train_meters'])
-            del extra_state['train_meters']
+        if extra_state is not None:
+            epoch = extra_state['train_iterator']['epoch']
+            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
+                filename, epoch, self.get_num_updates()))

-            # reset TimeMeters, since their start times don't make sense anymore
-            for meter in self.meters.values():
-                if isinstance(meter, TimeMeter):
-                    meter.reset()
+            self.lr_step(epoch)
+            self.lr_step_update(self.get_num_updates())
+
+            if 'train_meters' in extra_state:
+                self.meters.update(extra_state['train_meters'])
+                del extra_state['train_meters']
+
+                # reset TimeMeters, since their start times don't make sense anymore
+                for meter in self.meters.values():
+                    if isinstance(meter, TimeMeter):
+                        meter.reset()
+        else:
+            print('| no existing checkpoint found {}'.format(filename))

         return extra_state

+    def get_train_iterator(self, epoch, combine=True):
+        """Return an EpochBatchIterator over the training set for a given epoch."""
+        print('| loading train data for epoch {}'.format(epoch))
+        self.task.load_dataset(self.args.train_subset, epoch=epoch, combine=combine)
+        return self.task.get_batch_iterator(
+            dataset=self.task.dataset(self.args.train_subset),
+            max_tokens=self.args.max_tokens,
+            max_sentences=self.args.max_sentences,
+            max_positions=utils.resolve_max_positions(
+                self.task.max_positions(),
+                self.model.max_positions(),
+            ),
+            ignore_invalid_inputs=True,
+            required_batch_size_multiple=self.args.required_batch_size_multiple,
+            seed=self.args.seed,
+            num_shards=self.args.distributed_world_size,
+            shard_id=self.args.distributed_rank,
+            num_workers=self.args.num_workers,
+            epoch=epoch,
+        )
+
     def train_step(self, samples, dummy_batch=False, raise_oom=False):
         """Do forward, backward and parameter update."""
         if self._dummy_batch is None:
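
The new Trainer.get_train_iterator is driven only by the epoch number: it reloads the train subset for that epoch and rebuilds the batch iterator from the trainer's own task, model and args. A brief usage sketch, assuming a fully constructed Trainer instance named trainer:

    # fresh iterator at the start of training
    epoch_itr = trainer.get_train_iterator(epoch=0)

    # rebuild for a later epoch, e.g. when the dataset is loaded shard-by-shard
    epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)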

tests/test_train.py

@@ -69,10 +69,9 @@ class TestLoadCheckpoint(unittest.TestCase):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
+            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

-            with patch('fairseq.checkpoint_utils.reload_train', return_value=epoch_itr):
-                checkpoint_utils.load_checkpoint(
-                    self.args_mock, trainer, epoch_itr, 512, None)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)

             self.assertEqual(epoch_itr.epoch, 2)
             self.assertEqual(epoch_itr.iterations_in_epoch, 50)
@@ -99,10 +98,9 @@ class TestLoadCheckpoint(unittest.TestCase):
     def test_load_full_checkpoint(self):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
+            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

-            with patch('fairseq.checkpoint_utils.reload_train', return_value=epoch_itr):
-                checkpoint_utils.load_checkpoint(
-                    self.args_mock, trainer, epoch_itr, 512, None)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
             itr = epoch_itr.next_epoch_itr(shuffle=False)

             self.assertEqual(epoch_itr.epoch, 3)
@@ -112,9 +110,10 @@ class TestLoadCheckpoint(unittest.TestCase):
     def test_load_no_checkpoint(self):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
+            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
             self.patches['os.path.isfile'].return_value = False
-            checkpoint_utils.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
             itr = epoch_itr.next_epoch_itr(shuffle=False)

             self.assertEqual(epoch_itr.epoch, 1)

train.py

@@ -64,27 +64,9 @@ def main(args, init_distributed=False):
         args.max_sentences,
     ))

-    max_positions = utils.resolve_max_positions(
-        task.max_positions(),
-        model.max_positions(),
-    )
-
-    # Initialize dataloader
-    epoch_itr = task.get_batch_iterator(
-        dataset=task.dataset(args.train_subset),
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences,
-        max_positions=max_positions,
-        ignore_invalid_inputs=True,
-        required_batch_size_multiple=args.required_batch_size_multiple,
-        seed=args.seed,
-        num_shards=args.distributed_world_size,
-        shard_id=args.distributed_rank,
-        num_workers=args.num_workers,
-    )
-
-    # Load the latest checkpoint if one is available
-    checkpoint_utils.load_checkpoint(
-        args, trainer, epoch_itr, max_positions, task)
+    # Load the latest checkpoint if one is available and restore the
+    # corresponding train iterator
+    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
@@ -106,10 +88,11 @@ def main(args, init_distributed=False):
         # save checkpoint
         if epoch_itr.epoch % args.save_interval == 0:
-            checkpoint_utils.save_checkpoint(
-                args, trainer, epoch_itr, valid_losses[0])
+            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

-        epoch_itr = checkpoint_utils.reload_train(args, epoch_itr, max_positions, task)
+        if ':' in args.data:
+            # sharded data: get train iterator for next epoch
+            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)
+
     train_meter.stop()
     print('| done training in {:.1f} seconds'.format(train_meter.sum))
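
The colon check in the hunk above mirrors the deleted reload_train logic: a colon-separated args.data value denotes multiple data shards, so the train iterator is rebuilt after each epoch, while a single path leaves the loaded dataset alone. A small sketch of that decision; the example path values are illustrative, not taken from the diff:

    # args.data = 'data-bin/shard0:data-bin/shard1'  -> sharded, reload each epoch
    # args.data = 'data-bin'                         -> single path, nothing to do
    if ':' in args.data:
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)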