Add optimizer history to checkpoints (and rearrange criterions slightly)

Myle Ott 2017-09-27 19:33:19 -07:00
parent 48631f7a3c
commit e432459b37
6 changed files with 98 additions and 49 deletions
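
In rough terms, the checkpoint written by save_checkpoint now looks like the sketch below. The field names come from the save_checkpoint diff further down; the model, optimizer, and values are toy placeholders, not fairseq code.

# Toy illustration (not fairseq code) of the new checkpoint layout, with placeholder values.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optim_history = []  # history carried over from any previously loaded checkpoint
checkpoint = {
    'args': None,                 # the argparse.Namespace in real runs
    'epoch': 1,
    'batch_offset': 0,
    'model': model.state_dict(),
    'val_loss': None,
    # each entry records which criterion the optimizer state was trained with
    'optimizer_history': optim_history + [
        {
            'criterion_name': 'CrossEntropyCriterion',
            'optimizer': optimizer.state_dict(),
            'best_loss': 2.5,
        }
    ],
}
torch.save(checkpoint, 'checkpoint_example.pt')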

View File

@@ -18,14 +18,14 @@ class CrossEntropyCriterion(FairseqCriterion):
         super().__init__()
         self.padding_idx = padding_idx
-    def prepare(self, samples):
-        self.denom = sum(s['ntokens'] if s else 0 for s in samples)
+    def grad_denom(self, samples):
+        return sum(s['ntokens'] if s else 0 for s in samples)
     def forward(self, net_output, sample):
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        return loss / self.denom
+        return loss
     def aggregate(self, losses):
         return sum(losses) / math.log(2)

View File

@@ -11,13 +11,17 @@ from torch.nn.modules.loss import _Loss
 class FairseqCriterion(_Loss):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
+        super().__init__()
-    def prepare(self, samples):
-        """Prepare criterion for DataParallel training."""
+    def grad_denom(self, samples):
+        """Gradient normalization term for DataParallel training."""
         raise NotImplementedError
+    def prepare(self, model, sample):
+        """Apply criterion-specific modifications to the sample."""
+        return sample
     def forward(self, net_output, sample):
         """Compute the loss for the given sample and network output."""
         raise NotImplementedError
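
Taken together with the cross-entropy diff above, this changes the criterion contract: forward() returns an unnormalized (summed) loss, grad_denom(samples) exposes the shared normalization term, and prepare(model, sample) becomes a per-sample hook. Below is a standalone sketch of that contract; it does not import fairseq, so the FairseqCriterion base class is replaced by a plain object and reduction='sum' stands in for the older size_average=False spelling used in the diff.

# Standalone sketch (no fairseq import) mirroring the updated criterion contract.
import math
import torch
import torch.nn.functional as F

class ToyCrossEntropyCriterion(object):
    def __init__(self, padding_idx):
        self.padding_idx = padding_idx

    def grad_denom(self, samples):
        # normalization term, computed once over the samples of all replicas
        return sum(s['ntokens'] if s else 0 for s in samples)

    def prepare(self, model, sample):
        # hook for criterion-specific sample modifications; identity here
        return sample

    def forward(self, net_output, sample):
        # returns the *unnormalized* loss; the trainer divides by grad_denom
        input = net_output.view(-1, net_output.size(-1))
        target = sample['target'].view(-1)
        return F.cross_entropy(input, target, reduction='sum',
                               ignore_index=self.padding_idx)

    def aggregate(self, losses):
        return sum(losses) / math.log(2)

# toy usage: one replica's sample, normalized by the global token count
criterion = ToyCrossEntropyCriterion(padding_idx=0)
samples = [{'ntokens': 6, 'target': torch.tensor([1, 2, 3, 1, 2, 3])}]
net_output = torch.randn(6, 5)
loss = criterion.forward(net_output, samples[0]) / criterion.grad_denom(samples)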

View File

@@ -49,14 +49,14 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         self.padding_idx = padding_idx
         self.weights = weights
-    def prepare(self, samples):
-        self.denom = sum(s['ntokens'] if s else 0 for s in samples)
+    def grad_denom(self, samples):
+        return sum(s['ntokens'] if s else 0 for s in samples)
     def forward(self, net_output, sample):
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        return loss / self.denom
+        return loss
     def aggregate(self, losses):
         return sum(losses) / math.log(2)

View File

@@ -32,7 +32,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     (prefixed with `_async_`), which run on each process in parallel.
     """
-    def __init__(self, args, model, device_ids=None,
+    def __init__(self, args, model, criterion, device_ids=None,
                  multiprocessing_method='spawn'):
         if device_ids is None:
             device_ids = tuple(range(torch.cuda.device_count()))
@@ -42,16 +42,17 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             raise NotImplementedError('Training on CPU is not supported')
         model = model.share_memory()
         nccl_uid = nccl.get_unique_id()
+        self.criterion = criterion
         Future.gen_list([
             self.call_async(rank, '_async_init', args=args, model=model,
-                            nccl_uid=nccl_uid)
+                            criterion=criterion, nccl_uid=nccl_uid)
             for rank in range(self.num_replicas)
         ])
         self._grads_initialized = False
-    def _async_init(self, rank, device_id, args, model, nccl_uid):
+    def _async_init(self, rank, device_id, args, model, criterion, nccl_uid):
         """Initialize child processes."""
         self.args = args
@@ -64,8 +65,9 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # initialize NCCL
         nccl.initialize(self.num_replicas, nccl_uid, device_id)
-        # copy model to current device
+        # copy model and criterion to current device
         self.model = model.cuda()
+        self.criterion = criterion.cuda()
         # initialize optimizer
         self.optimizer = NAG(self.model.parameters(), lr=self.args.lr,
@@ -104,8 +106,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                         batch_offset=batch_offset, val_loss=val_loss).gen()
     def _async_save_checkpoint(self, rank, device_id, args, epoch, batch_offset, val_loss):
-        utils.save_checkpoint(args, epoch, batch_offset, self.model,
-                              self.optimizer, self.lr_scheduler, val_loss)
+        utils.save_checkpoint(args, epoch, batch_offset, self.model, self.criterion,
+                              self.optimizer, self.lr_scheduler, val_loss, self._optim_history)
     def load_checkpoint(self, filename):
         """Load a checkpoint into the model replicas in each process."""
@@ -117,13 +119,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return epoch, batch_offset
     def _async_load_checkpoint(self, rank, device_id, filename):
-        return utils.load_checkpoint(filename, self.model, self.optimizer,
-                                     self.lr_scheduler, cuda_device=device_id)
+        epoch, batch_offset, self._optim_history = utils.load_checkpoint(
+            filename, self.model, self.criterion, self.optimizer, self.lr_scheduler,
+            cuda_device=device_id)
+        return epoch, batch_offset
-    def train_step(self, samples, criterion):
+    def train_step(self, samples):
         """Do forward, backward and gradient step in parallel."""
-        assert isinstance(criterion, FairseqCriterion)
         # PyTorch initializes gradient buffers lazily, so the first
         # train step needs to send non-empty samples to all replicas
         replace_empty_samples = False
@@ -133,31 +135,36 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # scatter sample across GPUs
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
-        criterion.prepare(samples)
+        # calculate gradient normalization term
+        grad_denom = self.criterion.grad_denom(samples)
         # forward pass, backward pass and gradient step
         losses = [
-            self.call_async(rank, '_async_train_step', criterion=criterion)
+            self.call_async(rank, '_async_train_step', grad_denom=grad_denom)
             for rank in range(self.num_replicas)
         ]
         # aggregate losses and gradient norms
         losses, grad_norms = Future.gen_tuple_list(losses)
-        loss = criterion.aggregate(losses)
+        loss = self.criterion.aggregate(losses)
         return loss, grad_norms[0]
-    def _async_train_step(self, rank, device_id, criterion):
+    def _async_train_step(self, rank, device_id, grad_denom):
         self.model.train()
-        # zero grads even if net_input is None, since we will all-reduce them
+        # zero grads even if self._sample is None, since we will all-reduce them
         self.optimizer.zero_grad()
         # calculate loss and grads
         loss = 0
         if self._sample is not None:
+            self._sample = self.criterion.prepare(self.model, self._sample)
             net_output = self.model(**self._sample['net_input'])
-            loss_ = criterion(net_output, self._sample)
+            loss_ = self.criterion(net_output, self._sample)
+            if grad_denom is not None:
+                loss_ /= grad_denom
             loss_.backward()
             loss = loss_.data[0]
@@ -196,29 +203,34 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             flat_grads.div_(coef)
         return norm
-    def valid_step(self, samples, criterion):
+    def valid_step(self, samples):
         """Do forward pass in parallel."""
         # scatter sample across GPUs
         self._scatter_samples(samples, volatile=True)
-        criterion.prepare(samples)
+        # calculate gradient normalization term
+        grad_denom = self.criterion.grad_denom(samples)
         # forward pass
         losses = [
-            self.call_async(rank, '_async_valid_step', criterion=criterion)
+            self.call_async(rank, '_async_valid_step', grad_denom=grad_denom)
             for rank in range(self.num_replicas)
         ]
         # aggregate losses
-        loss = criterion.aggregate(Future.gen_list(losses))
+        loss = self.criterion.aggregate(Future.gen_list(losses))
         return loss
-    def _async_valid_step(self, rank, device_id, criterion):
+    def _async_valid_step(self, rank, device_id, grad_denom):
         if self._sample is None:
             return 0
         self.model.eval()
+        self._sample = self.criterion.prepare(self.model, self._sample)
         net_output = self.model(**self._sample['net_input'])
-        loss = criterion(net_output, self._sample)
+        loss = self.criterion(net_output, self._sample)
+        if grad_denom is not None:
+            loss /= grad_denom
         return loss.data[0]
     def get_lr(self):
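
The net effect on the trainer is that the gradient denominator is computed once from all scattered samples, and each replica divides its local loss by it before backward(), so the all-reduced gradients are normalized by the global token count. A tiny numeric check of that bookkeeping (plain Python, no GPUs or multiprocessing):

# Two replicas with unnormalized (summed) losses and their token counts.
replica_losses = [12.0, 8.0]
replica_ntokens = [6, 4]

grad_denom = sum(replica_ntokens)                        # what grad_denom(samples) returns
per_replica = [l / grad_denom for l in replica_losses]   # division done on each replica
# summing the replica contributions equals the overall per-token loss: (12 + 8) / 10 == 2.0
assert abs(sum(per_replica) - sum(replica_losses) / grad_denom) < 1e-12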

View File

@@ -46,15 +46,23 @@ def torch_persistent_save(*args, **kwargs):
         logging.error(traceback.format_exc())
-def save_checkpoint(args, epoch, batch_offset, model, optimizer, lr_scheduler, val_loss=None):
+def save_checkpoint(args, epoch, batch_offset, model, criterion, optimizer, lr_scheduler,
+                    val_loss=None, optim_history=None):
+    if optim_history is None:
+        optim_history = []
     state_dict = {
         'args': args,
         'epoch': epoch,
         'batch_offset': batch_offset,
         'model': model.state_dict(),
-        'optimizer': optimizer.state_dict(),
-        'best_loss': lr_scheduler.best,
         'val_loss': val_loss,
+        'optimizer_history': optim_history + [
+            {
+                'criterion_name': criterion.__class__.__name__,
+                'optimizer': optimizer.state_dict(),
+                'best_loss': lr_scheduler.best,
+            }
+        ],
     }
     if batch_offset == 0:
@@ -72,9 +80,9 @@ def save_checkpoint(args, epoch, batch_offset, model, optimizer, lr_scheduler, v
     torch_persistent_save(state_dict, last_filename)
-def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
+def load_checkpoint(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
     if not os.path.exists(filename):
-        return 1, 0
+        return 1, 0, []
     if cuda_device is None:
         state = torch.load(filename)
     else:
@@ -82,16 +90,41 @@ def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
             filename,
             map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
         )
+    state = _upgrade_state_dict(state)
     model.load_state_dict(state['model'])
-    optimizer.load_state_dict(state['optimizer'])
-    lr_scheduler.best = state['best_loss']
     epoch = state['epoch'] + 1
     batch_offset = state['batch_offset']
+    # only load optimizer and lr_scheduler if they match with the checkpoint
+    opt_str = ''
+    optim_history = state['optimizer_history']
+    last_optim = optim_history[-1]
+    if last_optim['criterion_name'] == criterion.__class__.__name__:
+        optimizer.load_state_dict(last_optim['optimizer'])
+        lr_scheduler.best = last_optim['best_loss']
+        opt_str = '; criterion: {}'.format(last_optim['criterion_name'])
     gpu_str = ' on GPU #{}'.format(cuda_device) if cuda_device is not None else ''
-    print('| loaded checkpoint {} (epoch {}){}'.format(filename, epoch, gpu_str))
-    return epoch, batch_offset
+    print('| loaded checkpoint {} (epoch {}{}){}'.format(filename, epoch, opt_str, gpu_str))
+    return epoch, batch_offset, optim_history
+def _upgrade_state_dict(state):
+    """Helper for upgrading old model checkpoints."""
+    # add optimizer_history
+    if 'optimizer_history' not in state:
+        state['optimizer_history'] = [
+            {
+                'criterion_name': criterions.CrossEntropyCriterion.__name__,
+                'optimizer': state['optimizer'],
+                'best_loss': state['best_loss'],
+            },
+        ]
+        del state['optimizer']
+        del state['best_loss']
+    return state
 def load_ensemble_for_inference(filenames, data_path, split):
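
_upgrade_state_dict keeps checkpoints written before this commit loadable by folding the flat 'optimizer' / 'best_loss' fields into a single-entry optimizer_history. A self-contained before/after illustration; the function body mirrors the diff, with the criterion name hard-coded instead of taken from fairseq.criterions.

# Standalone illustration of upgrading an old-style checkpoint dict.
def upgrade_state_dict(state):
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',  # fairseq uses criterions.CrossEntropyCriterion.__name__
                'optimizer': state['optimizer'],
                'best_loss': state['best_loss'],
            },
        ]
        del state['optimizer']
        del state['best_loss']
    return state

old_state = {'epoch': 3, 'batch_offset': 0, 'model': {}, 'val_loss': 4.2,
             'optimizer': {'state': {}, 'param_groups': []}, 'best_loss': 4.0}
new_state = upgrade_state_dict(old_state)
assert 'optimizer' not in new_state and 'best_loss' not in new_state
assert new_state['optimizer_history'][-1]['criterion_name'] == 'CrossEntropyCriterion'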

View File

@@ -68,7 +68,7 @@ def main():
     criterion = utils.build_criterion(args, dataset)
     # Start multiprocessing
-    trainer = MultiprocessingTrainer(args, model)
+    trainer = MultiprocessingTrainer(args, model, criterion)
     # Load the latest checkpoint if one is available
     epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))
@@ -81,11 +81,11 @@ def main():
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch:
         # train for one epoch
-        train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus)
+        train(args, epoch, batch_offset, trainer, dataset, num_gpus)
         # evaluate on validate set
         for k, subset in enumerate(args.valid_subset.split(',')):
-            val_loss = validate(args, epoch, trainer, criterion, dataset, subset, num_gpus)
+            val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus)
             if k == 0:
                 if not args.no_save:
                     # save checkpoint
@@ -102,7 +102,7 @@ def main():
     trainer.stop()
-def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
+def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     """Train the model for one epoch."""
     itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
@@ -121,7 +121,7 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
     lr = trainer.get_lr()
     with progress_bar(itr, desc, leave=False) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
-            loss, grad_norm = trainer.train_step(sample, criterion)
+            loss, grad_norm = trainer.train_step(sample)
             ntokens = sum(s['ntokens'] for s in sample)
             src_size = sum(s['src_tokens'].size(0) for s in sample)
@@ -160,7 +160,7 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
               gnorm_meter.avg))
-def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
+def validate(args, epoch, trainer, dataset, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
     itr = dataset.dataloader(subset, batch_size=None,
@@ -173,7 +173,7 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
     with progress_bar(itr, desc, leave=False) as t:
         for _, sample in data.skip_group_enumerator(t, ngpus):
             ntokens = sum(s['ntokens'] for s in sample)
-            loss = trainer.valid_step(sample, criterion)
+            loss = trainer.valid_step(sample)
             loss_meter.update(loss, ntokens)
             t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)
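
The train.py hunks are the call-site half of the move: the criterion is handed to the trainer once in main(), and the per-step calls drop it. Summarized as a before/after sketch taken directly from the hunks above:

# Before this commit:
#   trainer = MultiprocessingTrainer(args, model)
#   loss, grad_norm = trainer.train_step(sample, criterion)
#   loss = trainer.valid_step(sample, criterion)
#
# After this commit:
#   trainer = MultiprocessingTrainer(args, model, criterion)
#   loss, grad_norm = trainer.train_step(sample)
#   loss = trainer.valid_step(sample)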