Average local optimizer params after warmup and during BMUF sync

Summary: We have seen that averaging the local optimizer parameters across GPUs, instead of resetting or broadcasting them after warmup, improves the WER.

Reviewed By: skritika

Differential Revision: D16739278

fbshipit-source-id: 75033d2d25f9a88fd6dd325d0d9d4c856d22d947
Nayan Singhal 2019-09-12 10:56:16 -07:00 committed by Facebook Github Bot
parent 3e3fe72299
commit 1fd8943e94
3 changed files with 78 additions and 45 deletions
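
In other words, after the warmup broadcast the per-GPU optimizer state is now averaged rather than reloaded from its initial snapshot. A minimal standalone simulation of the difference (plain tensors stand in for one Adam "exp_avg" buffer; no process group is involved and the values are made up for illustration):

import torch

# Four hypothetical workers, each holding its own copy of one optimizer-state buffer.
worker_states = [torch.full((2,), v) for v in (0.10, 0.30, 0.20, 0.40)]
initial_state = torch.zeros(2)

# Old behavior: every worker reloads the optimizer's initial state_dict.
reset_state = initial_state.clone()

# New behavior (--average-sync): every worker keeps the element-wise mean of all local states.
averaged_state = torch.stack(worker_states).mean(dim=0)

print(reset_state)     # tensor([0., 0.])
print(averaged_state)  # tensor([0.2500, 0.2500])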

fairseq/optim/adam.py

@@ -8,6 +8,7 @@ import types
 
 import torch
 import torch.optim
+import torch.distributed as dist
 
 from . import FairseqOptimizer, register_optimizer
@@ -53,6 +54,17 @@ class FairseqAdam(FairseqOptimizer):
             'weight_decay': self.args.weight_decay,
         }
 
+    def average_params(self):
+        """Reduce Params is only used during BMUF distributed training."""
+        state_dict = self.optimizer.state_dict()
+        total_gpus = float(dist.get_world_size())
+
+        for _, value in state_dict["state"].items():
+            value["exp_avg"] /= total_gpus
+            value["exp_avg_sq"] /= total_gpus
+            dist.all_reduce(value["exp_avg"], op=dist.ReduceOp.SUM)
+            dist.all_reduce(value["exp_avg_sq"], op=dist.ReduceOp.SUM)
+
 
 class Adam(torch.optim.Optimizer):
     """Implements Adam algorithm.

fairseq/optim/bmuf.py

@@ -31,6 +31,7 @@ class FairseqBMUF(FairseqOptimizer):
         self.warmup_iteration = self.args.warmup_iterations
         self.use_nbm = self.args.use_nbm
         self.initial_state = self._optimizer.state_dict()
+        self.average_sync = self.args.average_sync
 
     @staticmethod
     def add_args(parser):
@@ -62,6 +63,12 @@ class FairseqBMUF(FairseqOptimizer):
             action="store_true",
             help="Specify whether you want to use classical BM / Nesterov BM",
         )
+        parser.add_argument(
+            "--average-sync",
+            default=True,
+            action="store_true",
+            help="Specify whether you want to average the local momentum after each sync",
+        )
 
     @property
     def optimizer(self):
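
A note on the flag as declared above: with default=True and action="store_true", passing --average-sync on the command line changes nothing, so the averaging is effectively always on unless args.average_sync is overridden programmatically. A small standalone argparse demo:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--average-sync", default=True, action="store_true")

print(parser.parse_args([]).average_sync)                  # True (flag omitted)
print(parser.parse_args(["--average-sync"]).average_sync)  # True (flag given)
# argparse generates no "--no-average-sync" here, so False is unreachable from the CLI.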
@@ -91,34 +98,50 @@ class FairseqBMUF(FairseqOptimizer):
         """Clips gradient norm."""
         return self._optimizer.clip_grad_norm(max_norm)
 
+    def average_params(self):
+        self._optimizer.average_params()
+
     def _block_sync(self):
-        # Update the global model using local models from all GPUs.
-        if self._is_bmuf_iter():
-            if self.block_momentum != 0:
-                self._BM_before_sync()
+        # Update the global model using local models from all GPUs
+        # (Step-1) Calculate grad between previously synced model and
+        # current local model
+        if self.block_momentum != 0:
+            self._calc_grad()
 
-            self._allreduce_parameter()
+        # (Step-2) Average gradient from all GPUs
+        self._avg_grad_from_all_gpus()
 
-            if self.block_momentum != 0:
-                self._BM_after_sync()
+        # (Step-3) Calculate global momentum and update the global model
+        if self.block_momentum != 0:
+            self._update_global_model()
+
+        # (Step-4) Average local optimizer params
+        if self.average_sync:
+            self.average_params()
 
     def _is_warmup_end(self):
         # Check whether train iterations is equal to warmup iter
         if self.get_num_updates() == self.warmup_iteration:
             return True
         return False
 
     def _is_bmuf_iter(self):
         # Check whether train iterations is equal to bmuf sync iter
         if self.get_num_updates() % self.sync_iter == 0:
             return True
         return False
 
-    def _warmup_sync(self, rootRank=0):
-        # broadcast the local model to all GPUs
+    def _warmup_sync(self, root_rank=0):
+        # Broadcast the local model to all gpus
         for param in self.params:
-            dist.broadcast(param.data, src=rootRank)
+            dist.broadcast(param.data, src=root_rank)
+
+        # Update local optimizer state
+        if self.average_sync:
+            self._optimizer.average_params()
+        else:
+            self._optimizer.load_state_dict(self.initial_state)
 
-        # Reset the local optimizer state and local bmuf related param
-        self._optimizer.load_state_dict(self.initial_state)
         self._reset_local_data()
 
     def step(self, closure=None):
@@ -127,7 +150,7 @@ class FairseqBMUF(FairseqOptimizer):
         self.set_num_updates(self.get_num_updates() + 1)
         if self._is_warmup_end():
             self._warmup_sync()
-        else:
+        elif self._is_bmuf_iter():
             self._block_sync()
 
     def zero_grad(self):
@@ -144,61 +167,56 @@ class FairseqBMUF(FairseqOptimizer):
     @torch.no_grad()
     def _reset_local_data(self):
-        """Resetting all the BMUF specific params."""
-        self.params_localprev = [torch.zeros_like(p.data) for p in self.params]
-        self.smoothed_grads_localprev = [
-            p.data.new_zeros(p.data.size()) for p in self.params
-        ]
-        self.grads_localprev = [p.data.new_zeros(p.data.size()) for p in self.params]
+        # (Step-0) Initialize global momentum parameters and store global copy on each gpu
+        self.global_params = [torch.zeros_like(p.data) for p in self.params]
+        self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params]
+        self.grads = [p.data.new_zeros(p.data.size()) for p in self.params]
 
         # saving the global model locally for calculating gradient during bmuf sync
-        for param, copy_param in zip(self.params, self.params_localprev):
-            copy_param.copy_(param.data)
+        for param, global_param in zip(self.params, self.global_params):
+            global_param.copy_(param.data)
 
     @torch.no_grad()
-    def _BM_before_sync(self):
-        """Calculate grad between previously synced model and current local model."""
-        # prev_param is basically the global copy from the previously finished
+    def _calc_grad(self):
+        # global_params is basically the global copy from the previously finished
         # synchronisation. param.data is local parameter after block_sync_freq
         # for the local gpu. so grad is difference between previously synced
         # model and current local model.
-        for index, (param, prev_param) in enumerate(
-            zip(self.params, self.params_localprev)
+        for index, (param, global_param) in enumerate(
+            zip(self.params, self.global_params)
         ):
-            self.grads_localprev[index] = prev_param - param.data
+            self.grads[index] = global_param - param.data
 
-    def _allreduce_parameter(self):
-        """Average gradient from all the GPUs."""
+    def _avg_grad_from_all_gpus(self):
         for index, param in enumerate(self.params):
-            sync_para = (
-                param.data if self.block_momentum == 0 else self.grads_localprev[index]
-            )
+            sync_para = param.data if self.block_momentum == 0 else self.grads[index]
             sync_para /= float(dist.get_world_size())
             dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)
 
     @torch.no_grad()
-    def _BM_after_sync(self):
-        for index, (param, prev_param, smoothed_grad, grad) in enumerate(
+    def _update_global_model(self):
+        for index, (param, global_param, smoothed_grad, grad) in enumerate(
             zip(
                 self.params,
-                self.params_localprev,
-                self.smoothed_grads_localprev,
-                # all machines would share the same value of smoothed_grad, since it is
+                self.global_params,
+                self.smoothed_grads,
+                # all gpus would share the same value of smoothed_grad, since it is
                 # always computed on synchronized gradients.
-                self.grads_localprev,
+                self.grads,
             )
         ):
-            # prev_param is basically last synchronized parameter. though
+            # global_param is basically last synchronized parameter. though
             # smoothed_grad is local, all processes will have same value of
             # smoothed_grad and hence param is globally synchronized copy.
-            # smoothed_grad(t)=BM * smoothed_grad(t-1) + BM_lr*grad(t)
-            smoothed_grad = smoothed_grad * self.block_momentum + grad * self.block_lr
-            param.data.copy_(prev_param - smoothed_grad)
+            # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
+            smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad
+            param.data.copy_(global_param - smoothed_grad)
 
             # A Nesterov momentum here is to do a partial weight update before
             # calculating the gradient
             if self.use_nbm:
                 param.data.copy_(param.data - self.block_momentum * smoothed_grad)
 
             # backup for the next synchronization.
-            self.smoothed_grads_localprev[index] = smoothed_grad
-            prev_param.copy_(param.data)
+            self.smoothed_grads[index] = smoothed_grad
+            global_param.copy_(param.data)
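
For reference, a worked scalar version of the block update performed by _update_global_model() above (block_momentum, block_lr, and the tensor values are illustrative assumptions, not the configured defaults):

import torch

block_momentum, block_lr, use_nbm = 0.875, 1.0, True

global_param  = torch.tensor(1.00)   # last synchronized value (global_params[i])
local_param   = torch.tensor(0.90)   # local value after a block of local updates
smoothed_grad = torch.tensor(0.02)   # momentum buffer carried across syncs

grad = global_param - local_param                                  # block "gradient" (averaged over GPUs in Step-2)
smoothed_grad = block_momentum * smoothed_grad + block_lr * grad   # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
param = global_param - smoothed_grad                               # new globally synchronized parameter
if use_nbm:
    param = param - block_momentum * smoothed_grad                 # Nesterov-style partial lookahead

print(param.item(), smoothed_grad.item())  # these become global_param and smoothed_grad for the next block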

fairseq/optim/fairseq_optimizer.py

@@ -108,3 +108,6 @@ class FairseqOptimizer(object):
         if hasattr(self.optimizer, 'supports_memory_efficient_fp16'):
             return self.optimizer.supports_memory_efficient_fp16
         return False
+
+    def average_params(self):
+        pass
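
The base-class no-op is what lets FairseqBMUF call average_params() unconditionally: optimizers that override it (like FairseqAdam above) average their state, and everything else silently skips Step-4. A generic sketch of the pattern with hypothetical classes (not fairseq code):

class Optimizer:
    def average_params(self):
        pass  # default: nothing to average

class AdamLike(Optimizer):
    def average_params(self):
        print("all-reducing exp_avg / exp_avg_sq across GPUs")

for opt in (Optimizer(), AdamLike()):
    opt.average_params()  # safe to call on any optimizer during a BMUF sync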