More flexible gradient normalization

This commit is contained in:
Myle Ott 2017-10-06 17:26:33 -04:00
parent 88a8bd42c8
commit 3f9700868d
4 changed files with 97 additions and 62 deletions

View File

@ -18,19 +18,29 @@ class CrossEntropyCriterion(FairseqCriterion):
super().__init__()
self.padding_idx = padding_idx
def grad_denom(self, samples):
return sum(s['ntokens'] if s else 0 for s in samples)
def forward(self, model, sample):
"""Compute the loss for the given sample.
def forward(self, model, sample, grad_denom):
Returns a tuple with three elements:
1) the loss, as a Variable
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
input = net_output.view(-1, net_output.size(-1))
target = sample['target'].view(-1)
loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
return {
'loss': loss / grad_denom,
sample_size = sample['ntokens']
logging_output = {
'loss': loss.data[0],
'sample_size': sample_size,
}
return loss, sample_size, logging_output
def aggregate(self, loss_dicts):
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
}

View File

@ -14,14 +14,22 @@ class FairseqCriterion(_Loss):
def __init__(self):
super().__init__()
def grad_denom(self, samples):
"""Gradient normalization term for DataParallel training."""
def forward(self, model, sample):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss, as a Variable
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
raise NotImplementedError
def forward(self, model, sample, grad_denom):
"""Compute the loss for the given sample and network output."""
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError
def aggregate(self, losses, log_infos):
"""Aggregate losses from DataParallel training."""
raise NotImplementedError
@staticmethod
def grad_denom(sample_sizes):
"""Compute the gradient denominator for a set of sample sizes."""
return sum(sample_sizes)

View File

@ -49,19 +49,29 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
self.padding_idx = padding_idx
self.weights = weights
def grad_denom(self, samples):
return sum(s['ntokens'] if s else 0 for s in samples)
def forward(self, model, sample):
"""Compute the loss for the given sample.
def forward(self, model, sample, grad_denom):
Returns a tuple with three elements:
1) the loss, as a Variable
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
target = sample['target'].view(-1)
loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
return {
'loss': loss / grad_denom,
sample_size = sample['ntokens']
logging_output = {
'loss': loss.data[0],
'sample_size': sample_size,
}
return loss, sample_size, logging_output
def aggregate(self, loss_dicts):
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
}

View File

@ -15,7 +15,6 @@ import torch
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from fairseq import nccl, utils
from fairseq.criterions import FairseqCriterion
from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
from fairseq.nag import NAG
@ -74,6 +73,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
momentum=self.args.momentum,
weight_decay=self.args.weight_decay)
self.flat_grads = None
self.loss = None
# initialize LR scheduler
self.lr_scheduler = self._build_lr_scheduler()
@ -136,35 +136,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# scatter sample across GPUs
self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
# calculate gradient normalization term
grad_denom = self.criterion.grad_denom(samples)
# forward pass, backward pass and gradient step
losses = [
self.call_async(rank, '_async_train_step', grad_denom=grad_denom)
# forward pass
sample_sizes, logging_outputs = Future.gen_tuple_list([
self.call_async(rank, '_async_forward')
for rank in range(self.num_replicas)
]
])
# aggregate losses and gradient norms
loss_dicts = Future.gen_list(losses)
loss_dict = self.criterion.aggregate(loss_dicts)
loss_dict['gnorm'] = loss_dicts[0]['gnorm']
# backward pass, all-reduce gradients and take an optimization step
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
grad_norms = Future.gen_list([
self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
for rank in range(self.num_replicas)
])
return loss_dict
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
logging_output['gnorm'] = grad_norms[0] # log the gradient norm
def _async_train_step(self, rank, device_id, grad_denom):
self.model.train()
return logging_output
# zero grads even if self._sample is None, since we will all-reduce them
self.optimizer.zero_grad()
def _async_forward(self, rank, device_id, eval=False):
if eval:
self.model.eval()
else:
self.model.train()
self.optimizer.zero_grad()
# calculate loss and grads
loss = 0
loss_dict = {}
if self._sample is not None:
loss_dict = self.criterion(self.model, self._sample, grad_denom)
loss_dict['loss'].backward()
loss = loss_dict['loss'].data[0]
if self._sample is None:
return 0, {}
# calculate loss and sample size
self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
return sample_size, logging_output
def _async_backward_and_opt(self, rank, device_id, grad_denom):
if self.loss is not None:
# backward pass
self.loss.backward()
# flatten grads into a contiguous block of memory
if self.flat_grads is None:
@ -173,13 +182,20 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# all-reduce grads
nccl.all_reduce(self.flat_grads)
# normalize grads
if grad_denom != 0:
self.flat_grads.div_(grad_denom)
# clip grads
loss_dict['gnorm'] = self._clip_grads_(self.flat_grads, self.args.clip_norm)
grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)
# take an optimization step
self.optimizer.step()
return loss_dict
# reset loss
self.loss = None
return grad_norm
def _flatten_grads_(self, model):
num_params = sum(p.data.numel() for p in model.parameters())
@ -206,25 +222,16 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# scatter sample across GPUs
self._scatter_samples(samples, volatile=True)
# calculate gradient normalization term
grad_denom = self.criterion.grad_denom(samples)
# forward pass
losses = [
self.call_async(rank, '_async_valid_step', grad_denom=grad_denom)
_sample_sizes, logging_outputs = Future.gen_tuple_list([
self.call_async(rank, '_async_forward', eval=True)
for rank in range(self.num_replicas)
]
])
# aggregate losses
loss_dict = self.criterion.aggregate(Future.gen_list(losses))
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
return loss_dict
def _async_valid_step(self, rank, device_id, grad_denom):
if self._sample is None:
return {}
self.model.eval()
return self.criterion(self.model, self._sample, grad_denom)
return logging_output
def get_lr(self):
"""Get the current learning rate."""