Force certain optimizers to set --fp16-no-flatten-grads (#1010)

Summary:
When training with `--fp16` we flatten the grads by default, since it's faster. But flattened grads are not semantically equivalent for certain optimizers (e.g., Adafactor, LAMB), so until now the user had to know about this limitation and set `--fp16-no-flatten-grads` themselves. Let's raise a RuntimeError in that case instead.
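
For context, "flattening" means copying every parameter's FP16 gradient into one contiguous FP32 buffer so that the grad norm and the optimizer step operate on a single tensor. A rough sketch of the idea (illustrative only, not fairseq's actual implementation):

import torch

def flatten_grads(params):
    # copy each parameter's grad into one contiguous FP32 buffer
    grads = [p.grad.detach().reshape(-1).float() for p in params if p.grad is not None]
    return torch.cat(grads)

w1 = torch.nn.Parameter(torch.randn(2, 3))
w2 = torch.nn.Parameter(torch.randn(4))
(w1.sum() + w2.sum()).backward()
print(flatten_grads([w1, w2]).shape)  # torch.Size([10]); per-parameter shapes are gone

Optimizers such as Adafactor (which factors its second-moment statistics along each parameter's rows and columns) and LAMB (which scales updates by per-layer trust ratios) need the per-parameter view, so a single flat tensor is not enough for them.
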
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1010

Differential Revision: D19575773

Pulled By: myleott

fbshipit-source-id: bac99c3026f9870e6127e0fa55f70e8a3e4507dc
Myle Ott authored on 2020-01-28 08:01:08 -08:00; committed by Facebook Github Bot
parent c2671c1720
commit 61aad8f9cd
13 changed files with 109 additions and 13 deletions

View File

@@ -41,3 +41,7 @@ class Adadelta(FairseqOptimizer):
             'eps': self.args.adadelta_eps,
             'weight_decay': self.args.weight_decay,
         }
+
+    @property
+    def supports_flat_params(self):
+        return True

View File

@@ -102,6 +102,10 @@ class Adafactor(torch.optim.Optimizer):
     def supports_memory_efficient_fp16(self):
         return True
 
+    @property
+    def supports_flat_params(self):
+        return False
+
     def _get_lr(self, param_group, param_state):
         rel_step_sz = param_group['lr']
         if param_group['relative_step']:

View File

@@ -34,3 +34,7 @@ class Adagrad(FairseqOptimizer):
             'lr': self.args.lr[0],
             'weight_decay': self.args.weight_decay,
         }
+
+    @property
+    def supports_flat_params(self):
+        return True

View File

@@ -124,6 +124,10 @@ class Adam(torch.optim.Optimizer):
     def supports_memory_efficient_fp16(self):
         return True
 
+    @property
+    def supports_flat_params(self):
+        return True
+
     def step(self, closure=None):
         """Performs a single optimization step.

View File

@@ -88,6 +88,10 @@ class Adamax(torch.optim.Optimizer):
     def supports_memory_efficient_fp16(self):
         return True
 
+    @property
+    def supports_flat_params(self):
+        return True
+
     def step(self, closure=None):
         """Performs a single optimization step.

View File

@@ -5,6 +5,8 @@
 import torch
 
+from fairseq import utils
+
 
 class FairseqOptimizer(object):
@@ -86,10 +88,7 @@ class FairseqOptimizer(object):
     def clip_grad_norm(self, max_norm):
         """Clips gradient norm."""
-        if max_norm > 0:
-            return torch.nn.utils.clip_grad_norm_(self.params, max_norm)
-        else:
-            return torch.sqrt(sum(p.grad.data.norm()**2 for p in self.params if p.grad is not None))
+        return utils.clip_grad_norm_(self.params, max_norm)
 
     def step(self, closure=None):
         """Performs a single optimization step."""
@@ -107,5 +106,15 @@ class FairseqOptimizer(object):
             return self.optimizer.supports_memory_efficient_fp16
         return False
 
+    @property
+    def supports_flat_params(self):
+        """
+        Whether the optimizer supports collapsing of the model
+        parameters/gradients into a single contiguous Tensor.
+        """
+        if hasattr(self.optimizer, 'supports_flat_params'):
+            return self.optimizer.supports_flat_params
+        return False
+
     def average_params(self):
         pass
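
Because the wrapper defers to the wrapped torch optimizer and defaults to False, an optimizer is treated as incompatible with flattening unless it explicitly says otherwise. A standalone illustration of that dispatch (ToyOptimizer and ToyWrapper are hypothetical names mirroring the hasattr logic above):

class ToyOptimizer:
    # hypothetical optimizer that declares flat-param support explicitly
    @property
    def supports_flat_params(self):
        return True

class ToyWrapper:
    # mirrors FairseqOptimizer.supports_flat_params from the hunk above
    def __init__(self, optimizer):
        self.optimizer = optimizer

    @property
    def supports_flat_params(self):
        if hasattr(self.optimizer, 'supports_flat_params'):
            return self.optimizer.supports_flat_params
        return False

print(ToyWrapper(ToyOptimizer()).supports_flat_params)  # True, taken from the optimizer
print(ToyWrapper(object()).supports_flat_params)        # False, the conservative default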

View File

@@ -157,9 +157,9 @@ class _FP16OptimizerMixin(object):
         """Clips gradient norm and updates dynamic loss scaler."""
         self._sync_fp16_grads_to_fp32()
         if self.has_flat_params:
-            grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)
+            grad_norm = utils.clip_grad_norm_([self.fp32_params.grad.data], max_norm)
         else:
-            grad_norm = torch.nn.utils.clip_grad_norm_(self.fp32_params, max_norm)
+            grad_norm = utils.clip_grad_norm_(self.fp32_params, max_norm)
 
         # detect overflow and adjust loss scale
         overflow = DynamicLossScaler.has_overflow(grad_norm)
@@ -250,6 +250,11 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
             fp32_optimizer = optim.build_optimizer(args, [fp32_params])
         else:
             fp32_optimizer = optim.build_optimizer(args, fp32_params)
+        if flatten and not fp32_optimizer.supports_flat_params:
+            raise RuntimeError(
+                'chosen optimizer does not support flat params, '
+                'please set --fp16-no-flatten-grads'
+            )
         return cls(args, params, fp32_optimizer, fp32_params)
 
     @property
@@ -273,6 +278,10 @@ class _MemoryEfficientFP16OptimizerMixin(object):
         # forward __init__ call to the next class in mro(method resolution order)
         super().__init__(*args, **kwargs)
 
+    @property
+    def has_flat_params(self):
+        return False
+
     def state_dict(self):
         """Return the optimizer's state dict."""
         state_dict = self.wrapped_optimizer.state_dict()
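
The user-visible effect of the new check can be sketched standalone as follows (FakeAdafactor and build_fp16_optimizer are made-up names; only the guard and its message come from the diff above):

class FakeAdafactor:
    # stands in for an optimizer whose supports_flat_params is False
    supports_flat_params = False

def build_fp16_optimizer(fp32_optimizer, flatten=True):
    # the same guard that FP16Optimizer.build_optimizer now applies
    if flatten and not fp32_optimizer.supports_flat_params:
        raise RuntimeError(
            'chosen optimizer does not support flat params, '
            'please set --fp16-no-flatten-grads'
        )
    return fp32_optimizer

try:
    build_fp16_optimizer(FakeAdafactor())                 # default --fp16 path: flatten=True
except RuntimeError as e:
    print(e)                                              # tells the user which flag to set

build_fp16_optimizer(FakeAdafactor(), flatten=False)      # the --fp16-no-flatten-grads path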

View File

@@ -90,6 +90,10 @@ class FusedAdamV1(torch.optim.Optimizer):
     def supports_memory_efficient_fp16(self):
         return True
 
+    @property
+    def supports_flat_params(self):
+        return True
+
     def step(self, closure=None, grads=None, scale=1., grad_norms=None):
         """Performs a single optimization step.
 
         Arguments:
@@ -209,6 +213,10 @@ try:
         def supports_memory_efficient_fp16(self):
             return True
 
+        @property
+        def supports_flat_params(self):
+            return True
+
         def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
             """Performs a single optimization step."""
             loss = None

View File

@@ -44,3 +44,7 @@ class FairseqLAMB(FairseqOptimizer):
             'eps': self.args.lamb_eps,
             'weight_decay': self.args.weight_decay,
         }
+
+    @property
+    def supports_flat_params(self):
+        return False

View File

@@ -49,6 +49,10 @@ class NAG(Optimizer):
     def supports_memory_efficient_fp16(self):
         return True
 
+    @property
+    def supports_flat_params(self):
+        return True
+
     def step(self, closure=None):
         """Performs a single optimization step.

View File

@@ -37,3 +37,7 @@ class SGD(FairseqOptimizer):
             'momentum': self.args.momentum,
             'weight_decay': self.args.weight_decay,
         }
+
+    @property
+    def supports_flat_params(self):
+        return True

View File

@@ -227,12 +227,21 @@ def item(tensor):
     return tensor
 
 
-def clip_grad_norm_(tensor, max_norm):
-    grad_norm = item(torch.norm(tensor))
-    if grad_norm > max_norm > 0:
-        clip_coef = max_norm / (grad_norm + 1e-6)
-        tensor.mul_(clip_coef)
-    return grad_norm
+def clip_grad_norm_(params, max_norm):
+    params = list(params)
+    if len(params) == 1:
+        p = params[0]
+        grad_norm = torch.norm(p)
+        if grad_norm > max_norm > 0:
+            clip_coef = max_norm / (grad_norm + 1e-6)
+            p.mul_(clip_coef)
+        return grad_norm
+    elif max_norm > 0:
+        return torch.nn.utils.clip_grad_norm_(params, max_norm)
+    else:
+        return torch.sqrt(
+            sum(p.grad.data.norm()**2 for p in params if p.grad is not None)
+        )
 
 
 def fill_with_neg_inf(t):
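
The rewritten helper accepts either a single pre-flattened gradient tensor wrapped in a list (as the flat-params branch of _FP16OptimizerMixin now passes) or an iterable of parameters with populated .grad. A small usage sketch, assuming fairseq at this revision is importable:

import torch
from fairseq import utils

# flat-grads path: one pre-flattened gradient tensor, wrapped in a list
flat_grad = torch.randn(10)
print(utils.clip_grad_norm_([flat_grad], 0.1))   # clips flat_grad in place when its norm exceeds 0.1

# non-flat path: parameters with .grad set; falls back to torch.nn.utils.clip_grad_norm_
params = [torch.nn.Parameter(torch.randn(3)) for _ in range(2)]
sum(p.sum() for p in params).backward()
print(utils.clip_grad_norm_(params, 0.1))

# max_norm == 0 just reports the gradient norm without clipping
print(utils.clip_grad_norm_(params, 0))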

View File

@@ -660,7 +660,7 @@ def train_legacy_masked_language_model(data_dir, arch, extra_args=()):
     train.main(train_args)
 
 
-class TestCommonOptions(unittest.TestCase):
+class TestOptimizers(unittest.TestCase):
 
     def setUp(self):
         logging.disable(logging.CRITICAL)
@@ -688,6 +688,35 @@ class TestCommonOptions(unittest.TestCase):
             ])
             generate_main(data_dir)
 
+    @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU')
+    def test_flat_grads(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_flat_grads') as data_dir:
+                # Use just a bit of data and tiny model to keep this test runtime reasonable
+                create_dummy_data(data_dir, num_examples=10, maxlen=5)
+                preprocess_translation_data(data_dir)
+                with self.assertRaises(RuntimeError):
+                    # adafactor isn't compatible with flat grads, which
+                    # are used by default with --fp16
+                    train_translation_model(data_dir, 'lstm', [
+                        '--required-batch-size-multiple', '1',
+                        '--encoder-layers', '1',
+                        '--encoder-hidden-size', '32',
+                        '--decoder-layers', '1',
+                        '--optimizer', 'adafactor',
+                        '--fp16',
+                    ])
+                # but it should pass once we set --fp16-no-flatten-grads
+                train_translation_model(data_dir, 'lstm', [
+                    '--required-batch-size-multiple', '1',
+                    '--encoder-layers', '1',
+                    '--encoder-hidden-size', '32',
+                    '--decoder-layers', '1',
+                    '--optimizer', 'adafactor',
+                    '--fp16',
+                    '--fp16-no-flatten-grads',
+                ])
+
 
 def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False):
     def _create_dummy_data(filename):