diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py
index 9eca5120a..28b60aa51 100644
--- a/fairseq/criterions/cross_entropy.py
+++ b/fairseq/criterions/cross_entropy.py
@@ -18,19 +18,29 @@ class CrossEntropyCriterion(FairseqCriterion):
         super().__init__()
         self.padding_idx = padding_idx

-    def grad_denom(self, samples):
-        return sum(s['ntokens'] if s else 0 for s in samples)
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.

-    def forward(self, model, sample, grad_denom):
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
         net_output = model(**sample['net_input'])
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        return {
-            'loss': loss / grad_denom,
+        sample_size = sample['ntokens']
+        logging_output = {
+            'loss': loss.data[0],
+            'sample_size': sample_size,
         }
+        return loss, sample_size, logging_output

-    def aggregate(self, loss_dicts):
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         return {
-            'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
+            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
         }
diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py
index 2fbf37464..dd851e827 100644
--- a/fairseq/criterions/fairseq_criterion.py
+++ b/fairseq/criterions/fairseq_criterion.py
@@ -14,14 +14,22 @@ class FairseqCriterion(_Loss):
     def __init__(self):
         super().__init__()

-    def grad_denom(self, samples):
-        """Gradient normalization term for DataParallel training."""
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
         raise NotImplementedError

-    def forward(self, model, sample, grad_denom):
-        """Compute the loss for the given sample and network output."""
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
         raise NotImplementedError

-    def aggregate(self, losses, log_infos):
-        """Aggregate losses from DataParallel training."""
-        raise NotImplementedError
+    @staticmethod
+    def grad_denom(sample_sizes):
+        """Compute the gradient denominator for a set of sample sizes."""
+        return sum(sample_sizes)
diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py
index a60c83802..008935c55 100644
--- a/fairseq/criterions/label_smoothed_cross_entropy.py
+++ b/fairseq/criterions/label_smoothed_cross_entropy.py
@@ -49,19 +49,29 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         self.padding_idx = padding_idx
         self.weights = weights

-    def grad_denom(self, samples):
-        return sum(s['ntokens'] if s else 0 for s in samples)
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.

-    def forward(self, model, sample, grad_denom):
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
         net_output = model(**sample['net_input'])
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        return {
-            'loss': loss / grad_denom,
+        sample_size = sample['ntokens']
+        logging_output = {
+            'loss': loss.data[0],
+            'sample_size': sample_size,
         }
+        return loss, sample_size, logging_output

-    def aggregate(self, loss_dicts):
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         return {
-            'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
+            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
         }
diff --git a/fairseq/multiprocessing_trainer.py b/fairseq/multiprocessing_trainer.py
index 2b8a1716c..ed1cf3422 100644
--- a/fairseq/multiprocessing_trainer.py
+++ b/fairseq/multiprocessing_trainer.py
@@ -15,7 +15,6 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

 from fairseq import nccl, utils
-from fairseq.criterions import FairseqCriterion
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
 from fairseq.nag import NAG

@@ -74,6 +73,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                              momentum=self.args.momentum,
                              weight_decay=self.args.weight_decay)
         self.flat_grads = None
+        self.loss = None

         # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()
@@ -136,35 +136,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # scatter sample across GPUs
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)

-        # calculate gradient normalization term
-        grad_denom = self.criterion.grad_denom(samples)
-
-        # forward pass, backward pass and gradient step
-        losses = [
-            self.call_async(rank, '_async_train_step', grad_denom=grad_denom)
+        # forward pass
+        sample_sizes, logging_outputs = Future.gen_tuple_list([
+            self.call_async(rank, '_async_forward')
             for rank in range(self.num_replicas)
-        ]
+        ])

-        # aggregate losses and gradient norms
-        loss_dicts = Future.gen_list(losses)
-        loss_dict = self.criterion.aggregate(loss_dicts)
-        loss_dict['gnorm'] = loss_dicts[0]['gnorm']
+        # backward pass, all-reduce gradients and take an optimization step
+        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
+        grad_norms = Future.gen_list([
+            self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
+            for rank in range(self.num_replicas)
+        ])

-        return loss_dict
+        # aggregate logging output
+        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+        logging_output['gnorm'] = grad_norms[0]  # log the gradient norm

-    def _async_train_step(self, rank, device_id, grad_denom):
-        self.model.train()
+        return logging_output

-        # zero grads even if self._sample is None, since we will all-reduce them
-        self.optimizer.zero_grad()
+    def _async_forward(self, rank, device_id, eval=False):
+        if eval:
+            self.model.eval()
+        else:
+            self.model.train()
+            self.optimizer.zero_grad()

-        # calculate loss and grads
-        loss = 0
-        loss_dict = {}
-        if self._sample is not None:
-            loss_dict = self.criterion(self.model, self._sample, grad_denom)
-            loss_dict['loss'].backward()
-            loss = loss_dict['loss'].data[0]
+        if self._sample is None:
+            return 0, {}
+
+        # calculate loss and sample size
+        self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+
+        return sample_size, logging_output
+
+    def _async_backward_and_opt(self, rank, device_id, grad_denom):
+        if self.loss is not None:
+            # backward pass
+            self.loss.backward()

         # flatten grads into a contiguous block of memory
         if self.flat_grads is None:
@@ -173,13 +182,20 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # all-reduce grads
         nccl.all_reduce(self.flat_grads)

+        # normalize grads
+        if grad_denom != 0:
+            self.flat_grads.div_(grad_denom)
+
         # clip grads
-        loss_dict['gnorm'] = self._clip_grads_(self.flat_grads, self.args.clip_norm)
+        grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)

         # take an optimization step
         self.optimizer.step()

-        return loss_dict
+        # reset loss
+        self.loss = None
+
+        return grad_norm

     def _flatten_grads_(self, model):
         num_params = sum(p.data.numel() for p in model.parameters())
@@ -206,25 +222,16 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # scatter sample across GPUs
         self._scatter_samples(samples, volatile=True)

-        # calculate gradient normalization term
-        grad_denom = self.criterion.grad_denom(samples)
-
         # forward pass
-        losses = [
-            self.call_async(rank, '_async_valid_step', grad_denom=grad_denom)
+        _sample_sizes, logging_outputs = Future.gen_tuple_list([
+            self.call_async(rank, '_async_forward', eval=True)
             for rank in range(self.num_replicas)
-        ]
+        ])

-        # aggregate losses
-        loss_dict = self.criterion.aggregate(Future.gen_list(losses))
+        # aggregate logging output
+        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)

-        return loss_dict
-
-    def _async_valid_step(self, rank, device_id, grad_denom):
-        if self._sample is None:
-            return {}
-        self.model.eval()
-        return self.criterion(self.model, self._sample, grad_denom)
+        return logging_output

     def get_lr(self):
         """Get the current learning rate."""
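For orientation, a minimal sketch of the criterion contract this diff introduces: forward() returns a (loss, sample_size, logging_output) tuple, the static grad_denom() sums per-replica sample sizes into the gradient denominator, and aggregate_logging_outputs() reduces the per-replica logs for display. The ToyCrossEntropy class, the one-layer model, and the sample layout below are illustrative assumptions rather than code from the repository, and the snippet targets the same pre-0.4 PyTorch Variable API as the diff.

    import math

    import torch
    import torch.nn.functional as F
    from torch import nn
    from torch.autograd import Variable

    from fairseq.criterions import FairseqCriterion


    class ToyCrossEntropy(FairseqCriterion):
        """Toy criterion implementing the new three-method contract."""

        def __init__(self, padding_idx):
            super().__init__()
            self.padding_idx = padding_idx

        def forward(self, model, sample):
            # returns 1) the loss as a Variable, 2) the sample size, 3) logging outputs
            net_output = model(**sample['net_input'])
            loss = F.cross_entropy(net_output.view(-1, net_output.size(-1)),
                                   sample['target'].view(-1),
                                   size_average=False, ignore_index=self.padding_idx)
            sample_size = sample['ntokens']
            return loss, sample_size, {'loss': loss.data[0], 'sample_size': sample_size}

        @staticmethod
        def aggregate_logging_outputs(logging_outputs):
            # reduce per-replica logs into display stats (per-token loss in bits)
            sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
            loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
            return {'loss': loss_sum / sample_size / math.log(2)}


    # made-up single-replica driver, mirroring the train_step() flow above
    model = nn.Linear(5, 8)
    sample = {
        'net_input': {'input': Variable(torch.randn(3, 5))},
        'target': Variable(torch.LongTensor([1, 2, 3])),
        'ntokens': 3,
    }

    criterion = ToyCrossEntropy(padding_idx=0)
    loss, sample_size, logging_output = criterion(model, sample)   # forward
    grad_denom = ToyCrossEntropy.grad_denom([sample_size])         # sum of sample sizes
    loss.backward()                                                # grads are then divided by grad_denom
    print(ToyCrossEntropy.aggregate_logging_outputs([logging_output]))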