Fix resuming from FP16 checkpoints (#424)

Summary:
This was broken in 03a57de.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/424

Differential Revision: D13557540

Pulled By: myleott

fbshipit-source-id: 62deda5353032aff20d35d046b0bb843da44d27c
Myle Ott 2018-12-27 21:20:02 -08:00 committed by Facebook Github Bot
parent 31a43973e4
commit 58dd1862f6
2 changed files with 21 additions and 10 deletions


@@ -63,6 +63,20 @@ class ConvertToFP32(object):
         self.params = params
         self.itr = map(convert_to_fp32, params)
 
+    @staticmethod
+    def wrap_optimizer_(optimizer):
+        for group in optimizer.param_groups:
+            group['params'] = ConvertToFP32(group['params'])
+
+    @staticmethod
+    def unwrap_optimizer_(optimizer):
+        for group in optimizer.param_groups:
+            group['params'] = group['params'].params  # unwrap from ConvertToFP32
+            for p in group['params']:
+                p.data = p.data.half()
+                if p.grad is not None:
+                    p.grad.data = p.grad.data.half()
+
     def __len__(self):
         return len(self.params)
 
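
A standalone sketch of what the new helpers do (illustrative only; the import path for ConvertToFP32 is assumed from the file patched here, and the Linear module is just a toy parameter holder):

    import torch
    from fairseq.optim.fp16_optimizer import ConvertToFP32  # assumed module path

    model = torch.nn.Linear(4, 2).half()                  # FP16 parameters
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    ConvertToFP32.wrap_optimizer_(optimizer)
    # Iterating the wrapped param list rebinds each param's .data (and .grad) to FP32,
    # which is what the inner optimizer does inside step()/load_state_dict().
    print({p.dtype for g in optimizer.param_groups for p in g['params']})  # {torch.float32}

    ConvertToFP32.unwrap_optimizer_(optimizer)            # storage back to FP16
    print({p.dtype for p in model.parameters()})          # {torch.float16}
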
@@ -145,7 +159,9 @@ class FP16Optimizer(optim.FairseqOptimizer):
         """
         if 'loss_scale' in state_dict:
             self.scaler.loss_scale = state_dict['loss_scale']
+        ConvertToFP32.wrap_optimizer_(self.wrapped_optimizer.optimizer)
         self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)
+        ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer)
 
     def backward(self, loss):
         loss = loss * self.scaler.loss_scale
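
The wrapping above matters because torch.optim.Optimizer.load_state_dict casts floating-point state tensors to the dtype of the corresponding parameters; if the inner optimizer still held FP16 parameters while loading, the FP32 Adam state saved in the checkpoint would come back as FP16. A toy illustration of that casting behavior (not from the patch):

    import torch

    p32 = torch.nn.Parameter(torch.zeros(3))
    p16 = torch.nn.Parameter(torch.zeros(3, dtype=torch.half))

    opt32 = torch.optim.Adam([p32], lr=1e-3)
    p32.grad = torch.ones_like(p32)
    opt32.step()                                  # creates FP32 exp_avg / exp_avg_sq state

    opt16 = torch.optim.Adam([p16], lr=1e-3)
    opt16.load_state_dict(opt32.state_dict())     # state is cast to the param's dtype
    print(opt16.state[p16]['exp_avg'].dtype)      # torch.float16 -- precision lost
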
@@ -194,18 +210,12 @@
             self._unscale_grads()
 
         # convert params and grads to FP32 (lazily)
-        for group in self.wrapped_optimizer.optimizer.param_groups:
-            group['params'] = ConvertToFP32(group['params'])
+        ConvertToFP32.wrap_optimizer_(self.wrapped_optimizer.optimizer)
 
         self.wrapped_optimizer.step(closure)
 
         # convert params back to FP16
-        for group in self.wrapped_optimizer.optimizer.param_groups:
-            group['params'] = group['params'].params  # unwrap from ConvertToFP32
-            for p in group['params']:
-                p.data = p.data.half()
-                if p.grad is not None:
-                    p.grad.data = p.grad.data.half()
+        ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer)
 
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""


@@ -114,8 +114,9 @@ class Trainer(object):
     def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
         """Load all training state from a checkpoint file."""
-        extra_state, self._optim_history, last_optim_state = \
-            utils.load_model_state(filename, self.get_model())
+        extra_state, self._optim_history, last_optim_state = utils.load_model_state(
+            filename, self.get_model(),
+        )
 
         if last_optim_state is not None and not reset_optimizer:
             # rebuild optimizer after loading model, since params may have changed
             self._build_optimizer()
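
For completeness, a rough sketch of how this entry point is exercised when resuming (hypothetical: trainer is assumed to be an already-constructed fairseq Trainer for an --fp16 run, and the checkpoint path is a placeholder; fairseq's train.py normally drives this):

    # load_checkpoint returns the saved extra_state, or None if the file is missing
    extra_state = trainer.load_checkpoint('checkpoints/checkpoint_last.pt')
    if extra_state is not None:
        print('resumed: model, FP16 loss scale, and FP32 optimizer state restored')
    else:
        print('no checkpoint found; starting from scratch')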