Mirror of https://github.com/facebookresearch/fairseq.git
More flexible gradient normalization
commit 3f9700868d
parent 88a8bd42c8
@@ -18,19 +18,29 @@ class CrossEntropyCriterion(FairseqCriterion):
         super().__init__()
         self.padding_idx = padding_idx
 
-    def grad_denom(self, samples):
-        return sum(s['ntokens'] if s else 0 for s in samples)
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.
 
-    def forward(self, model, sample, grad_denom):
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
         net_output = model(**sample['net_input'])
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        return {
-            'loss': loss / grad_denom,
+        sample_size = sample['ntokens']
+        logging_output = {
+            'loss': loss.data[0],
+            'sample_size': sample_size,
         }
+        return loss, sample_size, logging_output
 
-    def aggregate(self, loss_dicts):
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         return {
-            'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
+            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
         }
@@ -14,14 +14,22 @@ class FairseqCriterion(_Loss):
     def __init__(self):
         super().__init__()
 
-    def grad_denom(self, samples):
-        """Gradient normalization term for DataParallel training."""
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
         raise NotImplementedError
 
-    def forward(self, model, sample, grad_denom):
-        """Compute the loss for the given sample and network output."""
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
         raise NotImplementedError
 
-    def aggregate(self, losses, log_infos):
-        """Aggregate losses from DataParallel training."""
-        raise NotImplementedError
+    @staticmethod
+    def grad_denom(sample_sizes):
+        """Compute the gradient denominator for a set of sample sizes."""
+        return sum(sample_sizes)
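For orientation, here is a minimal sketch (not part of this commit) of a criterion written against the new interface: forward() returns the tuple documented above, grad_denom() is inherited from FairseqCriterion and simply sums the per-replica sample sizes, and aggregate_logging_outputs() merges the per-replica logs. The class name and the squared-error loss are illustrative only, and the sketch follows the commit-era PyTorch idioms used in this diff (e.g. loss.data[0]).

import torch.nn.functional as F

from fairseq.criterions import FairseqCriterion


class ToySquaredErrorCriterion(FairseqCriterion):
    """Illustrative criterion; not part of fairseq."""

    def forward(self, model, sample):
        net_output = model(**sample['net_input'])
        # sum rather than average, so gradients can be normalized by grad_denom later
        loss = F.mse_loss(net_output, sample['target'], size_average=False)
        sample_size = sample['ntokens']
        logging_output = {
            'loss': loss.data[0],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output

    # grad_denom(sample_sizes) is inherited: it returns sum(sample_sizes)

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        return {
            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size,
        }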
@@ -49,19 +49,29 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         self.padding_idx = padding_idx
         self.weights = weights
 
-    def grad_denom(self, samples):
-        return sum(s['ntokens'] if s else 0 for s in samples)
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.
 
-    def forward(self, model, sample, grad_denom):
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
         net_output = model(**sample['net_input'])
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        return {
-            'loss': loss / grad_denom,
+        sample_size = sample['ntokens']
+        logging_output = {
+            'loss': loss.data[0],
+            'sample_size': sample_size,
        }
+        return loss, sample_size, logging_output
 
-    def aggregate(self, loss_dicts):
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         return {
-            'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
+            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
         }
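A note on the aggregated 'loss' above (the same formula appears in both criteria): each replica logs its loss summed over the tokens in its sample (size_average=False), in nats; dividing the total by sample_size gives nats per token, and the extra division by math.log(2) converts that to bits per token. A tiny arithmetic sketch with made-up numbers:

import math

logging_outputs = [
    {'loss': 693.1, 'sample_size': 100},   # replica 0 (illustrative values)
    {'loss': 1386.3, 'sample_size': 200},  # replica 1
]
sample_size = sum(log['sample_size'] for log in logging_outputs)   # 300 tokens
nats_per_token = sum(log['loss'] for log in logging_outputs) / sample_size
bits_per_token = nats_per_token / math.log(2)
print(round(bits_per_token, 2))  # ~10.0 bits per token (6.93 nats / ln 2)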
@@ -15,7 +15,6 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
 
 from fairseq import nccl, utils
-from fairseq.criterions import FairseqCriterion
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
 from fairseq.nag import NAG
 
@@ -74,6 +73,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                              momentum=self.args.momentum,
                              weight_decay=self.args.weight_decay)
         self.flat_grads = None
+        self.loss = None
 
         # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()
@@ -136,35 +136,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # scatter sample across GPUs
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
 
-        # calculate gradient normalization term
-        grad_denom = self.criterion.grad_denom(samples)
-
-        # forward pass, backward pass and gradient step
-        losses = [
-            self.call_async(rank, '_async_train_step', grad_denom=grad_denom)
+        # forward pass
+        sample_sizes, logging_outputs = Future.gen_tuple_list([
+            self.call_async(rank, '_async_forward')
             for rank in range(self.num_replicas)
-        ]
+        ])
 
-        # aggregate losses and gradient norms
-        loss_dicts = Future.gen_list(losses)
-        loss_dict = self.criterion.aggregate(loss_dicts)
-        loss_dict['gnorm'] = loss_dicts[0]['gnorm']
+        # backward pass, all-reduce gradients and take an optimization step
+        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
+        grad_norms = Future.gen_list([
+            self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
+            for rank in range(self.num_replicas)
+        ])
 
-        return loss_dict
+        # aggregate logging output
+        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+        logging_output['gnorm'] = grad_norms[0]  # log the gradient norm
 
-    def _async_train_step(self, rank, device_id, grad_denom):
-        self.model.train()
+        return logging_output
 
-        # zero grads even if self._sample is None, since we will all-reduce them
-        self.optimizer.zero_grad()
+    def _async_forward(self, rank, device_id, eval=False):
+        if eval:
+            self.model.eval()
+        else:
+            self.model.train()
+            self.optimizer.zero_grad()
 
-        # calculate loss and grads
-        loss = 0
-        loss_dict = {}
-        if self._sample is not None:
-            loss_dict = self.criterion(self.model, self._sample, grad_denom)
-            loss_dict['loss'].backward()
-            loss = loss_dict['loss'].data[0]
+        if self._sample is None:
+            return 0, {}
+
+        # calculate loss and sample size
+        self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+
+        return sample_size, logging_output
+
+    def _async_backward_and_opt(self, rank, device_id, grad_denom):
+        if self.loss is not None:
+            # backward pass
+            self.loss.backward()
 
         # flatten grads into a contiguous block of memory
         if self.flat_grads is None:
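The forward and backward passes are now issued as separate async calls because the gradient denominator depends on every replica's sample size, which is only known once all forwards have returned. A rough single-process sketch (not part of this commit; all-reduce and gradient flattening omitted) of the flow that train_step() distributes across replicas:

def single_process_train_step(model, criterion, optimizer, samples):
    """Illustrative only: the per-replica samples are processed sequentially here."""
    optimizer.zero_grad()

    # 1) forward pass: one (loss, sample_size, logging_output) tuple per sample
    losses, sample_sizes, logging_outputs = [], [], []
    for sample in samples:
        loss, sample_size, logging_output = criterion(model, sample)
        losses.append(loss)
        sample_sizes.append(sample_size)
        logging_outputs.append(logging_output)

    # 2) gradient denominator from all sample sizes
    grad_denom = criterion.__class__.grad_denom(sample_sizes)

    # 3) backward pass, normalize the summed gradients, take an optimization step
    for loss in losses:
        loss.backward()
    if grad_denom != 0:
        for p in model.parameters():
            if p.grad is not None:
                p.grad.data.div_(grad_denom)
    optimizer.step()

    # 4) aggregate logging output across samples
    return criterion.__class__.aggregate_logging_outputs(logging_outputs)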
@@ -173,13 +182,20 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # all-reduce grads
         nccl.all_reduce(self.flat_grads)
 
+        # normalize grads
+        if grad_denom != 0:
+            self.flat_grads.div_(grad_denom)
+
         # clip grads
-        loss_dict['gnorm'] = self._clip_grads_(self.flat_grads, self.args.clip_norm)
+        grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)
 
         # take an optimization step
         self.optimizer.step()
 
-        return loss_dict
+        # reset loss
+        self.loss = None
+
+        return grad_norm
 
     def _flatten_grads_(self, model):
         num_params = sum(p.data.numel() for p in model.parameters())
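Dividing the all-reduced (summed) gradients by grad_denom makes the update a per-token average that does not depend on how the tokens were split across GPUs. With made-up numbers:

grad_sums = [40.0, 120.0]     # per-replica gradient of the summed loss (illustrative)
sample_sizes = [100, 300]     # per-replica token counts

all_reduced = sum(grad_sums)        # what nccl.all_reduce produces: 160.0
grad_denom = sum(sample_sizes)      # FairseqCriterion.grad_denom(...): 400
print(all_reduced / grad_denom)     # 0.4 -- the same per-token gradient a single
                                    # GPU would compute from all 400 tokens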
@@ -206,25 +222,16 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # scatter sample across GPUs
         self._scatter_samples(samples, volatile=True)
 
-        # calculate gradient normalization term
-        grad_denom = self.criterion.grad_denom(samples)
-
         # forward pass
-        losses = [
-            self.call_async(rank, '_async_valid_step', grad_denom=grad_denom)
+        _sample_sizes, logging_outputs = Future.gen_tuple_list([
+            self.call_async(rank, '_async_forward', eval=True)
             for rank in range(self.num_replicas)
-        ]
+        ])
 
-        # aggregate losses
-        loss_dict = self.criterion.aggregate(Future.gen_list(losses))
+        # aggregate logging output
+        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
 
-        return loss_dict
-
-    def _async_valid_step(self, rank, device_id, grad_denom):
-        if self._sample is None:
-            return {}
-        self.model.eval()
-        return self.criterion(self.model, self._sample, grad_denom)
+        return logging_output
 
     def get_lr(self):
         """Get the current learning rate."""