optimize mixed precision (#1248)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
Brings the multiply-factor optimization used in memory-efficient FP16 training to mixed precision training. The methods `multiply_grads` and `clip_grad_norm` no longer modify each gradient tensor; instead they update a single "multiply factor" that is folded in when the gradients are unscaled.
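
In other words, every rescaling of the gradients (the inverse loss scale, `multiply_grads`, the clipping coefficient) is accumulated into one scalar, and the gradient tensors are rewritten only once, right before the parameter update. Below is a minimal standalone sketch of the idea; it is not the fairseq API, and `DeferredScaleOptimizer`, its constructor arguments, and the fixed loss scale are hypothetical simplifications.

```python
import torch


class DeferredScaleOptimizer(object):
    """Sketch: accumulate all gradient rescaling into a single scalar and
    apply it lazily, instead of multiplying every gradient tensor each time."""

    def __init__(self, params, inner_optimizer, loss_scale=128.0):
        self.params = [p for p in params if p.requires_grad]
        self.inner_optimizer = inner_optimizer
        self.loss_scale = loss_scale
        # pending factor; starts at the inverse loss scale so the grads are
        # unscaled exactly once, when they are finally touched in step()
        self._multiply_factor = 1.0 / loss_scale

    def backward(self, loss):
        (loss * self.loss_scale).backward()

    def multiply_grads(self, c):
        # O(1): no gradient tensor is modified here
        self._multiply_factor *= c

    def clip_grad_norm(self, max_norm):
        # norm of the raw (still loss-scaled) grads, corrected by the factor
        raw_norm = torch.norm(torch.stack(
            [p.grad.norm() for p in self.params if p.grad is not None]
        ))
        grad_norm = self._multiply_factor * raw_norm.item()
        if max_norm > 0 and grad_norm > max_norm:
            # fold the clip coefficient into the factor instead of scaling grads
            self._multiply_factor *= max_norm / grad_norm
        return grad_norm

    def step(self):
        # the only place where the gradient tensors are actually rewritten
        for p in self.params:
            if p.grad is not None:
                p.grad.mul_(self._multiply_factor)
        self._multiply_factor = 1.0
        self.inner_optimizer.step()

    def zero_grad(self):
        self.inner_optimizer.zero_grad()
        # re-arm the factor with the inverse loss scale for the next iteration
        self._multiply_factor = 1.0 / self.loss_scale
```

In the diff below the same bookkeeping is the `_multiply_factor` attribute of `_FP16OptimizerMixin`: it is applied by `_unscale_grads()` (or passed as `scale=` when the wrapped optimizer supports `step(closure, scale=...)`) right before `fp32_optimizer.step`, and reset from the dynamic loss scale in `zero_grad`.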

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1248

Reviewed By: myleott

Differential Revision: D23201396

Pulled By: andersonic

fbshipit-source-id: 6c6f64542893e0ecac72e132464bb334dcb9874d
Jun Ru Anderson 2020-08-19 16:03:14 -07:00 committed by Facebook GitHub Bot
parent 77983ee1a5
commit 68c87f0abf
2 changed files with 104 additions and 19 deletions

fairseq/optim/fp16_optimizer.py

@@ -75,12 +75,8 @@ class _FP16OptimizerMixin(object):
         loss.backward()
         self._needs_sync = True
 
-    def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
+    def _sync_fp16_grads_to_fp32(self):
         if self._needs_sync:
-            if self.scaler is not None:
-                # correct for dynamic loss scaler
-                multiply_grads /= self.scaler.loss_scale
-
             # copy FP16 grads to FP32
             if self.has_flat_params:
                 offset = 0
@@ -91,20 +87,18 @@ class _FP16OptimizerMixin(object):
                     numel = grad_data.numel()
                     self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
                     offset += numel
-                self.fp32_params.grad.data.mul_(multiply_grads)
             else:
                 for p, p32 in zip(self.fp16_params, self.fp32_params):
                     if not p.requires_grad:
                         continue
                     if p.grad is not None:
                         p32.grad.data.copy_(p.grad.data)
-                        p32.grad.data.mul_(multiply_grads)
                     else:
                         p32.grad = torch.zeros_like(p.data, dtype=torch.float)
 
             self._needs_sync = False
 
-    def _sync_fp32_grads_to_fp16(self):
+    def _sync_fp32_params_to_fp16(self):
         # copy FP32 params back into FP16 model
         if self.has_flat_params:
             offset = 0
@@ -120,36 +114,47 @@ class _FP16OptimizerMixin(object):
                     continue
                 p.data.copy_(p32.data)
 
+    def _unscale_grads(self):
+        self._sync_fp16_grads_to_fp32()
+        if self._multiply_factor != 1.:
+            self.fp32_optimizer.multiply_grads(self._multiply_factor)
+            self._multiply_factor = 1.
+
     def multiply_grads(self, c):
         """Multiplies grads by a constant ``c``."""
-        if self._needs_sync:
-            self._sync_fp16_grads_to_fp32(c)
-        elif self.has_flat_params:
-            self.fp32_params.grad.data.mul_(c)
-        else:
-            for p32 in self.fp32_params:
-                p32.grad.data.mul_(c)
+        self._multiply_factor *= c
 
     def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
         """Clips gradient norm and updates dynamic loss scaler."""
         self._sync_fp16_grads_to_fp32()
-        grad_norm = utils.clip_grad_norm_(self.fp32_params, max_norm, aggregate_norm_fn)
 
-        # detect overflow and adjust loss scale
+        grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(0, aggregate_norm_fn)
+
         if self.scaler is not None:
+            if grad_norm > max_norm > 0.0:
+                self._multiply_factor *= max_norm / grad_norm
+
             self.scaler.check_overflow(grad_norm)
+        else:
+            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
+            self._multiply_factor *= clip_coef
 
         return grad_norm
 
     def step(self, closure=None):
         """Performs a single optimization step."""
         self._sync_fp16_grads_to_fp32()
-        self.fp32_optimizer.step(closure)
+
+        if self.supports_step_with_scale:
+            self.fp32_optimizer.step(closure, scale=(1. / self._multiply_factor))
+        else:
+            self._unscale_grads()
+            self.fp32_optimizer.step(closure)
 
         if self.scaler is not None:
             self.scaler.update()
 
-        self._sync_fp32_grads_to_fp16()
+        self._sync_fp32_params_to_fp16()
 
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
@@ -162,6 +167,9 @@ class _FP16OptimizerMixin(object):
                 p32.grad.zero_()
         self._needs_sync = False
 
+        if self.scaler is not None:
+            self._multiply_factor = 1. / float(self.scaler.loss_scale)
+
 
 class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
     """

tests/test_fp16_optimizer.py

@@ -0,0 +1,77 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import copy
import unittest

import torch

from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer


@unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU')
class TestGradientScaling(unittest.TestCase):

    def setUp(self):
        self.x = torch.tensor([2.0]).cuda().half()
        weight = 3.0
        bias = 5.0
        self.error = 1.0
        self.target = torch.tensor([self.x * weight + bias + self.error]).cuda().half()
        self.loss_fn = torch.nn.L1Loss()

        self.model = torch.nn.Linear(1, 1)
        self.model.weight.data = torch.tensor([[weight]])
        self.model.bias.data = torch.tensor([bias])
        self.model.cuda().half()
        self.params = list(self.model.parameters())

        self.namespace_dls = argparse.Namespace(
            optimizer='adam',
            lr=[0.1],
            adam_betas='(0.9, 0.999)',
            adam_eps=1e-8,
            weight_decay=0.0,
            fp16_init_scale=1,
            fp16_scale_window=1,
            fp16_scale_tolerance=1,
            threshold_loss_scale=1,
            min_loss_scale=1e-4
        )

    def run_iter(self, model, params, optimizer):
        optimizer.zero_grad()
        y = model(self.x)
        loss = self.loss_fn(y, self.target)
        optimizer.backward(loss)
        self.assertEqual(loss, torch.tensor(1., device='cuda:0', dtype=torch.float16))

        grad_norm = optimizer.clip_grad_norm(0)
        self.assertAlmostEqual(grad_norm.item(), 2.2361, 4)

        optimizer.step()
        self.assertEqual(model.weight, torch.tensor([[3.0996]], device='cuda:0', dtype=torch.float16, requires_grad=True))
        self.assertEqual(model.bias, torch.tensor([5.1016], device='cuda:0', dtype=torch.float16, requires_grad=True))
        self.assertEqual(optimizer.scaler.loss_scale, 2.)

    def test_mixed_precision(self):
        model = copy.deepcopy(self.model)
        params = list(model.parameters())
        optimizer = FP16Optimizer.build_optimizer(self.namespace_dls, params)

        self.run_iter(model, params, optimizer)
        self.assertTrue(torch.all(optimizer.fp32_params.eq(torch.tensor([3.1000, 5.1000], device='cuda:0', requires_grad=True))))

    def test_memory_efficient(self):
        model = copy.deepcopy(self.model)
        params = list(model.parameters())
        optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.namespace_dls, params)

        self.run_iter(model, params, optimizer)


if __name__ == '__main__':
    unittest.main()