added fast stats sync option (#858)

Summary:
Added a `--fast-stat-sync` option.
This avoids pickling logging outputs for the cross-worker gather and achieves ~7% higher `wps` (words per second) on 16 nodes.
It is less flexible, since it aggregates only a fixed set of basic stats and ignores the aggregation function defined by the criterion.
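
For intuition, here is a minimal standalone sketch of the idea (illustrative only; the stat values and the `reduce_stats` helper are made up, the real logic is in `Trainer.train_step` below): each worker packs a fixed set of scalars into one tensor and a single `all_reduce` sums them, instead of pickling and gathering every worker's `logging_output` dict.

```python
# Minimal sketch of fast stat sync, assuming torch.distributed is already
# initialized. Names and values here are illustrative, not fairseq's API.
import torch
import torch.distributed as dist

# fixed slot order: sample_size, nsentences, loss, nll_loss, ntokens, ooms
local_stats = [300.0, 12.0, 2100.5, 2100.5, 300.0, 0.0]

def reduce_stats(stats, use_cuda=True):
    # pack the scalars into one tensor and sum it across all workers
    if use_cuda:
        t = torch.cuda.DoubleTensor(stats)
    else:
        t = torch.tensor(stats, dtype=torch.float64)
    dist.all_reduce(t)  # one collective call, no pickling of Python objects
    return t.tolist()

# sample_size, nsentences, loss, nll_loss, ntokens, ooms = reduce_stats(local_stats)
```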

Let me know what you think, myleott.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/858

Differential Revision: D17398770

fbshipit-source-id: 36261a1d970e67deeda8211af8f009ef9b4f9c14
Naman Goyal 2019-09-16 12:52:05 -07:00 committed by Facebook Github Bot
parent 1fd8943e94
commit e1ba32aae2
4 changed files with 66 additions and 11 deletions


@@ -30,6 +30,7 @@ class CrossEntropyCriterion(FairseqCriterion):
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
+            'nll_loss': utils.item(loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
             'nsentences': sample['target'].size(0),
             'sample_size': sample_size,


@@ -47,6 +47,7 @@ class MaskedLmLoss(FairseqCriterion):
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
+            'nll_loss': utils.item(loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
             'nsentences': sample['nsentences'],
             'sample_size': sample_size,


@@ -332,6 +332,9 @@ def add_distributed_training_args(parser):
     group.add_argument('--find-unused-parameters', default=False, action='store_true',
                        help='disable unused parameter detection (not applicable to '
                             'no_c10d ddp-backend')
+    group.add_argument('--fast-stat-sync', default=False, action='store_true',
+                       help='Enable fast sync of stats between nodes, this hardcodes to '
+                            'sync only some default stats from logging_output.')
     # fmt: on
     return group
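
For reference, a quick way to sanity-check that the flag parses (simplified argparse stand-in; the parser and group setup here are assumptions, only the option itself comes from the hunk above). In practice the flag is just appended to the usual training command.

```python
# Simplified stand-in for the option added above; only the flag definition
# comes from the diff, the surrounding parser is hypothetical.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('distributed_training')
group.add_argument('--fast-stat-sync', default=False, action='store_true',
                   help='Enable fast sync of stats between nodes, this hardcodes to '
                        'sync only some default stats from logging_output.')

args = parser.parse_args(['--fast-stat-sync'])
assert args.fast_stat_sync is True   # store_true: False by default, True when passed
```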


@@ -57,6 +57,11 @@ class Trainer(object):
         self._wrapped_criterion = None
         self._wrapped_model = None
 
+        # Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes.
+        # It is less flexible and syncs only the default stats.
+        self._all_reduce_list = [0.0] * 6
+        self.fast_stat_sync = args.fast_stat_sync
+
         self.init_meters(args)
 
     def init_meters(self, args):
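
As an annotation (not part of this commit), the six slots of `self._all_reduce_list` map to fixed stats, inferred from how `train_step` fills and unpacks the list further down:

```python
# Hypothetical documentation constant: slot layout of self._all_reduce_list,
# inferred from the train_step changes below.
FAST_STAT_SLOTS = ('sample_size', 'nsentences', 'loss', 'nll_loss', 'ntokens', 'ooms')
```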
@@ -292,6 +297,13 @@ class Trainer(object):
                 if not ignore_grad:
                     logging_outputs.append(logging_output)
                     sample_sizes.append(sample_size)
+
+                    if self.fast_stat_sync:
+                        self._all_reduce_list[0] += sample_size
+                        self._all_reduce_list[1] += logging_output.get('nsentences', 0.0)
+                        self._all_reduce_list[2] += logging_output.get('loss', 0.0)
+                        self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0)
+                        self._all_reduce_list[4] += logging_output.get('ntokens', 0.0)
             except RuntimeError as e:
                 if 'out of memory' in str(e):
                     msg = (
@@ -311,6 +323,10 @@ class Trainer(object):
                 else:
                     raise e
 
+        if self.fast_stat_sync:
+            self._all_reduce_list[5] += ooms
+
         if ooms > 0 and self._oom_batch is not None:
             self.handle_ooms(ooms)
@@ -318,13 +334,30 @@ class Trainer(object):
             return None
 
         # gather logging outputs from all replicas
-        if self.args.distributed_world_size > 1 and (
-            (not self.args.use_bmuf)
-            or (
-                self.args.use_bmuf
-                and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
-            )
-        ):
+        if self.fast_stat_sync:
+            # rework all_gather_list
+            all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
+            if self._sync_stats():
+                torch.distributed.all_reduce(all_reduce_list_tensor)
+                # Normalize loss and nll_loss by "sample_size"
+                # and convert to log base 2
+                all_reduce_list_tensor[2:4].div_(
+                    (
+                        all_reduce_list_tensor[0:1] *
+                        torch.log(torch.cuda.DoubleTensor([2]))
+                    )
+                )
+            self._all_reduce_list = all_reduce_list_tensor.tolist()
+            logging_output = {}
+            [
+                sample_size,
+                logging_output['nsentences'],
+                logging_output['loss'],
+                logging_output['nll_loss'],
+                logging_output['ntokens'],
+                ooms,
+            ] = self._all_reduce_list
+        elif self._sync_stats():
             logging_outputs, sample_sizes, ooms, prev_norms = \
                 zip(*distributed_utils.all_gather_list(
                     [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
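
The `div_` above normalizes the summed `loss` and `nll_loss` by the global `sample_size` and converts from natural log to log base 2. A tiny standalone numeric check of that step (made-up numbers):

```python
# loss is accumulated in nats and summed over tokens; dividing by
# sample_size * ln(2) yields a per-token value in log base 2.
import math

summed_loss = 693.147   # hypothetical global sum after all_reduce
sample_size = 100.0     # hypothetical global token count
per_token_base2 = summed_loss / (sample_size * math.log(2))
print(round(per_token_base2, 3))  # -> 10.0
```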
@@ -345,11 +378,12 @@ class Trainer(object):
             self.zero_grad()
             return None
 
-        # aggregate logging outputs and sample sizes
-        logging_output = self.task.aggregate_logging_outputs(
-            logging_outputs, self.get_criterion()
-        )
-        sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
+        if not self.fast_stat_sync:
+            # aggregate logging outputs and sample sizes
+            logging_output = self.task.aggregate_logging_outputs(
+                logging_outputs, self.get_criterion()
+            )
+            sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
 
         if not all(k in logging_output for k in ['ntokens', 'nsentences']):
             raise Exception((
@@ -400,6 +434,7 @@ class Trainer(object):
             self.meters['loss_scale'].reset()
             self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)
 
+        self.clear_buffered_stats()
         self.meters['train_wall'].stop()
 
         return logging_output
@@ -484,6 +519,9 @@ class Trainer(object):
     def zero_grad(self):
         self.optimizer.zero_grad()
 
+    def clear_buffered_stats(self):
+        self._all_reduce_list = [0.0] * 6
+
     def lr_step(self, epoch, val_loss=None):
         """Adjust the learning rate based on the validation loss."""
         self.lr_scheduler.step(epoch, val_loss)
@@ -545,3 +583,15 @@ class Trainer(object):
         torch.manual_seed(seed)
         if self.cuda:
             torch.cuda.manual_seed(seed)
+
+    def _sync_stats(self):
+        return (
+            self.args.distributed_world_size > 1 and
+            (
+                (not self.args.use_bmuf) or
+                (
+                    self.args.use_bmuf
+                    and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
+                )
+            )
+        )
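
To make the new `_sync_stats()` gate easier to scan, here is a standalone restatement with placeholder names (not the actual Trainer method):

```python
# Plain-function restatement of _sync_stats(): sync on every update for plain
# DDP, but only every `global_sync_iter`-th update when BMUF is enabled.
def should_sync(world_size, use_bmuf, num_updates, global_sync_iter):
    return world_size > 1 and (
        (not use_bmuf) or ((num_updates + 1) % global_sync_iter == 0)
    )

print(should_sync(16, False, 7, 50))   # True  (plain DDP syncs every update)
print(should_sync(16, True, 7, 50))    # False (BMUF: not a global-sync step)
print(should_sync(16, True, 49, 50))   # True  ((49 + 1) % 50 == 0)
print(should_sync(1, False, 7, 50))    # False (single worker, nothing to sync)
```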