diff --git a/distributed_train.py b/distributed_train.py
index c3e137bc2..f3725a6e1 100644
--- a/distributed_train.py
+++ b/distributed_train.py
@@ -30,7 +30,7 @@ def main(args):
         raise e
     except FileNotFoundError as e:  # Slurm is not installed
         pass
-    if args.distributed_init_method is None:
+    if args.distributed_init_method is None and args.distributed_port is None:
         raise ValueError('--distributed-init-method or --distributed-port '
                          'must be specified for distributed training')
 
diff --git a/docs/conf.py b/docs/conf.py
index ac733584f..3888492c8 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -60,9 +60,9 @@ github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/'
 # built documents.
 #
 # The short X.Y version.
-version = '0.5.0'
+version = '0.6.0'
 # The full version, including alpha/beta/rc tags.
-release = '0.5.0'
+release = '0.6.0'
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
diff --git a/docs/data.rst b/docs/data.rst
index 24a123628..fed632780 100644
--- a/docs/data.rst
+++ b/docs/data.rst
@@ -36,5 +36,7 @@ Iterators
     :members:
 .. autoclass:: fairseq.data.EpochBatchIterator
     :members:
+.. autoclass:: fairseq.data.GroupedIterator
+    :members:
 .. autoclass:: fairseq.data.ShardedIterator
     :members:
diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py
index e8529fd42..c684565c9 100644
--- a/fairseq/criterions/adaptive_loss.py
+++ b/fairseq/criterions/adaptive_loss.py
@@ -54,6 +54,7 @@ class AdaptiveLoss(FairseqCriterion):
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
+            'nsentences': sample['target'].size(0),
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
@@ -63,9 +64,12 @@ class AdaptiveLoss(FairseqCriterion):
         """Aggregate logging outputs from data parallel training."""
         loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         agg_output = {
             'loss': loss_sum / sample_size / math.log(2),
+            'ntokens': ntokens,
+            'nsentences': nsentences,
             'sample_size': sample_size,
         }
         if sample_size != ntokens:
diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py
index 152b5a613..8c94265f5 100644
--- a/fairseq/criterions/cross_entropy.py
+++ b/fairseq/criterions/cross_entropy.py
@@ -37,6 +37,7 @@ class CrossEntropyCriterion(FairseqCriterion):
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
+            'nsentences': sample['target'].size(0),
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
@@ -46,9 +47,12 @@ class CrossEntropyCriterion(FairseqCriterion):
         """Aggregate logging outputs from data parallel training."""
         loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         agg_output = {
             'loss': loss_sum / sample_size / math.log(2),
+            'ntokens': ntokens,
+            'nsentences': nsentences,
             'sample_size': sample_size,
         }
         if sample_size != ntokens:
diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py
index 970f7b9ce..e95b06674 100644
--- a/fairseq/criterions/label_smoothed_cross_entropy.py
+++ b/fairseq/criterions/label_smoothed_cross_entropy.py
@@ -40,6 +40,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
             'loss': utils.item(loss.data) if reduce else loss.data,
             'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
             'ntokens': sample['ntokens'],
+            'nsentences': sample['target'].size(0),
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
@@ -58,14 +59,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
         return loss, nll_loss
 
-    @staticmethod
     def aggregate_logging_outputs(logging_outputs):
         """Aggregate logging outputs from data parallel training."""
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         return {
             'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
             'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2),
+            'ntokens': ntokens,
+            'nsentences': nsentences,
             'sample_size': sample_size,
         }
 
diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py
index 7d80c44a9..c04c32487 100644
--- a/fairseq/data/__init__.py
+++ b/fairseq/data/__init__.py
@@ -12,18 +12,24 @@ from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .token_block_dataset import TokenBlockDataset
 
-from .iterators import CountingIterator, EpochBatchIterator, ShardedIterator
+from .iterators import (
+    CountingIterator,
+    EpochBatchIterator,
+    GroupedIterator,
+    ShardedIterator,
+)
 
 __all__ = [
     'CountingIterator',
     'Dictionary',
     'EpochBatchIterator',
     'FairseqDataset',
+    'GroupedIterator',
     'IndexedDataset',
     'IndexedInMemoryDataset',
     'IndexedRawTextDataset',
     'LanguagePairDataset',
     'MonolingualDataset',
-    'TokenBlockDataset',
     'ShardedIterator',
+    'TokenBlockDataset',
 ]
diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py
index 073dab600..31331baa0 100644
--- a/fairseq/data/iterators.py
+++ b/fairseq/data/iterators.py
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 
 import itertools
+import math
 
 import numpy as np
 import torch
@@ -150,6 +151,36 @@ class EpochBatchIterator(object):
         ))
 
 
+class GroupedIterator(object):
+    """Wrapper around an iterable that returns groups (chunks) of items.
+
+    Args:
+        iterable (iterable): iterable to wrap
+        chunk_size (int): size of each chunk
+    """
+
+    def __init__(self, iterable, chunk_size):
+        self._len = int(math.ceil(len(iterable) / float(chunk_size)))
+        self.itr = iter(iterable)
+        self.chunk_size = chunk_size
+
+    def __len__(self):
+        return self._len
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        chunk = []
+        try:
+            for _ in range(self.chunk_size):
+                chunk.append(next(self.itr))
+        except StopIteration as e:
+            if len(chunk) == 0:
+                raise e
+        return chunk
+
+
 class ShardedIterator(object):
     """A sharded wrapper around an iterable, padded to length.
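A quick sanity-check sketch of the GroupedIterator added above (not part of the patch; range(10) is just a stand-in for an epoch's batch iterator, any sized iterable works):

    from fairseq.data import iterators

    grouped = iterators.GroupedIterator(range(10), chunk_size=4)
    assert len(grouped) == 3
    for chunk in grouped:
        # yields [0, 1, 2, 3], [4, 5, 6, 7] and finally the short tail [8, 9]
        print(chunk)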
diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index a14e522a5..0ff0d0e5c 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -7,7 +7,9 @@ import pickle -import torch.distributed +import torch +from torch import distributed +from torch.distributed import group from fairseq import utils @@ -16,22 +18,39 @@ def is_master(args): return args.distributed_rank == 0 +_use_c10d = [None] + + def distributed_init(args): if args.distributed_world_size == 1: raise ValueError('Cannot initialize distributed with distributed_world_size=1') + if _use_c10d[0] is None: + _use_c10d[0] = not args.no_c10d + + if _use_c10d[0] and not hasattr(torch.nn.parallel, '_DistributedDataParallelC10d'): + _use_c10d[0] = False + print('WARNING: cannot find DistributedDataParallelC10d, ' + 'falling back to standard DistributedDataParallel') + print('| distributed init (rank {}): {}'.format( args.distributed_rank, args.distributed_init_method), flush=True) - if args.distributed_init_method.startswith('tcp://'): - torch.distributed.init_process_group( - backend=args.distributed_backend, init_method=args.distributed_init_method, - world_size=args.distributed_world_size, rank=args.distributed_rank) - else: - torch.distributed.init_process_group( - backend=args.distributed_backend, init_method=args.distributed_init_method, - world_size=args.distributed_world_size) - args.distributed_rank = torch.distributed.get_rank() + if _use_c10d[0]: + distributed.c10d.init_process_group( + backend=args.distributed_backend, + init_method=args.distributed_init_method, + world_size=args.distributed_world_size, + rank=args.distributed_rank, + ) + else: + distributed.init_process_group( + backend=args.distributed_backend, + init_method=args.distributed_init_method, + world_size=args.distributed_world_size, + rank=args.distributed_rank, + ) + if not is_master(args): suppress_output() @@ -52,35 +71,77 @@ def suppress_output(): __builtin__.print = print -def all_gather_list(data, max_size=16384): - """Gathers arbitrary data from all nodes into a list.""" - world_size = torch.distributed.get_world_size() - if not hasattr(all_gather_list, '_in_buffer') or \ - max_size != all_gather_list._in_buffer.size(): - all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) - all_gather_list._out_buffers = [ - torch.cuda.ByteTensor(max_size) - for i in range(world_size) - ] - in_buffer = all_gather_list._in_buffer - out_buffers = all_gather_list._out_buffers +def get_rank(): + if _use_c10d[0]: + return distributed.c10d.get_rank() + else: + return distributed.get_rank() + + +def get_world_size(): + if _use_c10d[0]: + return distributed.c10d.get_world_size() + else: + return distributed.get_world_size() + + +def get_default_group(): + if _use_c10d[0]: + return distributed.c10d.group.WORLD + else: + return distributed.group.WORLD + + +def all_reduce(tensor, group=None): + if group is None: + group = get_default_group() + if _use_c10d[0]: + return distributed.c10d.all_reduce(tensor, group=group) + else: + return distributed.all_reduce(tensor, group=group) + + +def all_gather_list(data, group=None, max_size=16384): + """Gathers arbitrary data from all nodes into a list. + + Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python + data. Note that *data* must be picklable. 
+ + Args: + data (Any): data from the local worker to be gathered on other workers + group (optional): group of the collective + max_size (int, optional): maximum size of the data to be gathered + across workers + """ + rank = get_rank() + world_size = get_world_size() + + buffer_size = max_size * world_size + if not hasattr(all_gather_list, '_buffer') or \ + all_gather_list._buffer.numel() < buffer_size: + all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) + buffer = all_gather_list._buffer + buffer.zero_() enc = pickle.dumps(data) enc_size = len(enc) if enc_size + 2 > max_size: raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2)) assert max_size < 255*256 - in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k - in_buffer[1] = enc_size % 255 - in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) - torch.distributed.all_gather(out_buffers, in_buffer.cuda()) + buffer_rank = buffer[rank * max_size : (rank + 1) * max_size] + buffer_rank[0] = enc_size // 255 # this encoding works for max_size < 65k + buffer_rank[1] = enc_size % 255 + buffer_rank[2:enc_size+2] = torch.ByteTensor(list(enc)) + + all_reduce(buffer, group=group) result = [] for i in range(world_size): - out_buffer = out_buffers[i] + out_buffer = buffer[i * max_size : (i + 1) * max_size] size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1]) - result.append( - pickle.loads(bytes(out_buffer[2:size+2].tolist())) - ) + if size > 0: + result.append( + pickle.loads(bytes(out_buffer[2:size+2].tolist())) + ) return result diff --git a/fairseq/fp16_trainer.py b/fairseq/fp16_trainer.py deleted file mode 100644 index 5ae88ab1e..000000000 --- a/fairseq/fp16_trainer.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) 2017-present, Facebook, Inc. -# All rights reserved. -# -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. - -""" -Train a network on multiple GPUs. -""" - -import torch - -from fairseq import optim, utils -from fairseq.meters import AverageMeter -from fairseq.optim import lr_scheduler -from fairseq.trainer import Trainer - - -class DynamicLossScaler: - - def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000): - self.loss_scale = init_scale - self.scale_factor = scale_factor - self.scale_window = scale_window - self._iter = 0 - self._last_overflow_iter = -1 - - def update_scale(self, overflow): - if overflow: - self.loss_scale /= self.scale_factor - self._last_overflow_iter = self._iter - elif (self._iter - self._last_overflow_iter) % self.scale_window == 0: - self.loss_scale *= self.scale_factor - self._iter += 1 - - @staticmethod - def has_overflow(grad_norm): - # detect inf and nan - if grad_norm == float('inf') or grad_norm != grad_norm: - return True - return False - - -class FP16Trainer(Trainer): - """Modified trainer for FP16. - - We maintain two copies of the model's parameters, both in FP16 and FP32. - We do forward/backward with FP16 and compute the loss + optimize with FP32. 
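    # Illustrative-only usage of the all_reduce-based all_gather_list defined above.
    # Assumes distributed_init(args) has already run on every worker and that CUDA
    # is available; the stats dict is a made-up payload for the example.
    from fairseq import distributed_utils

    stats = {'rank': distributed_utils.get_rank(), 'ntokens': 1234}
    gathered = distributed_utils.all_gather_list(stats)  # one entry per worker
    total_tokens = sum(s['ntokens'] for s in gathered)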
- """ - - def __init__(self, args, task, model, criterion): - super().__init__(args, task, model, criterion) - - # convert model to FP16 (but keep criterion FP32) - self.model.half() - - # dynamically scale loss to reduce overflow - self.scaler = DynamicLossScaler(init_scale=2.**7) - self.meters['loss_scale'] = AverageMeter() - - def _build_optimizer(self): - # create FP32 copy of parameters and grads - params = [p for p in self.model.parameters() if p.requires_grad] - total_param_size = sum(p.data.numel() for p in params) - self.fp32_params = params[0].new(0).float().new(total_param_size) - offset = 0 - for p in params: - numel = p.data.numel() - self.fp32_params[offset:offset+numel].copy_(p.data.view(-1)) - offset += numel - self.fp32_params = torch.nn.Parameter(self.fp32_params) - self.fp32_params.grad = self.fp32_params.data.new(total_param_size) - - # create optimizer using the copied FP32 params - self._optimizer = optim.build_optimizer(self.args, [self.fp32_params]) - self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer) - - def save_checkpoint(self, filename, extra_state): - """Save all training state in a checkpoint file.""" - extra_state['loss_scale'] = self.scaler.loss_scale - super().save_checkpoint(filename, extra_state) - - def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None): - """Load all training state from a checkpoint file.""" - extra_state = super().load_checkpoint(filename, reset_optimizer, reset_lr_scheduler, optimizer_overrides) - if extra_state is not None and 'loss_scale' in extra_state: - self.scaler.loss_scale = extra_state['loss_scale'] - return extra_state - - def zero_grad(self): - # zero both the FP16 and FP32 grads - self.model.zero_grad() # FP16 - self.optimizer.zero_grad() # FP32 - - def _backward(self, loss): - self.meters['loss_scale'].reset() - self.meters['loss_scale'].update(self.scaler.loss_scale) - if loss is not None: - # dynamically rescale loss to stay in FP16 range - loss = loss * self.scaler.loss_scale - return super()._backward(loss) - - def _all_reduce_and_rescale(self, grad_denom): - # undo effect of dynamic loss scaling on gradients - grad_denom *= self.scaler.loss_scale - - if self.args.distributed_world_size > 1: - # flatten grads into a single buffer - flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads) - - # scale gradients to avoid overflow in all-reduce - flat_grads.div_(self.args.distributed_world_size) - grad_denom /= self.args.distributed_world_size - - # all-reduce flat grads - torch.distributed.all_reduce(flat_grads) - - # copy grads back to FP32 - self.fp32_params.grad.data.copy_(flat_grads) - else: - # single worker: copy grads directly to FP32 - self._get_flat_grads(out=self.fp32_params.grad.data) - - # rescale and clip grads - self.fp32_params.grad.data.div_(grad_denom) - grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm) - - # detect overflow and adjust loss scale - overflow = DynamicLossScaler.has_overflow(grad_norm) - self.scaler.update_scale(overflow) - if overflow: - if self.scaler.loss_scale <= self.args.min_loss_scale: - raise Exception(( - 'Minimum loss scale reached ({}). Your loss is probably exploding. ' - 'Try lowering the learning rate, using gradient clipping or ' - 'increasing the batch size.' 
- ).format(self.args.min_loss_scale)) - raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale)) - - return grad_norm - - def _opt(self): - # take an optimization step using the FP32 params and grads - super()._opt() - - # copy FP32 params back into FP16 model - offset = 0 - for p in self.model.parameters(): - if not p.requires_grad: - continue - numel = p.data.numel() - p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data)) - offset += numel diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index b019fbb52..3a8646075 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -15,6 +15,7 @@ from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401 from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel # noqa: F401 from .composite_encoder import CompositeEncoder # noqa: F401 +from .distributed_fairseq_model import DistributedFairseqModel # noqa: F401 MODEL_REGISTRY = {} diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py new file mode 100644 index 000000000..bf4c911d3 --- /dev/null +++ b/fairseq/models/distributed_fairseq_model.py @@ -0,0 +1,62 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from torch.distributed import c10d +from torch.nn import parallel + +from . import BaseFairseqModel + + +class DistributedFairseqModel(BaseFairseqModel): + """ + A wrapper around a :class:`BaseFairseqModel` instance that adds support for + distributed training. + + Anytime a method or attribute is called on this class we first try to + forward it to the underlying DistributedDataParallel instance, otherwise we + forward it to the original :class:`BaseFairseqModel` instance. 
+ + Args: + args (argparse.Namespace): fairseq args + model (BaseFairseqModel): model to wrap + """ + + def __init__(self, args, model): + super().__init__() + assert isinstance(model, BaseFairseqModel) + if args.no_c10d: + self.ddp_model = parallel.DistributedDataParallel( + module=model, + device_ids=[args.device_id], + output_device=args.device_id, + broadcast_buffers=False, + ) + else: + self.ddp_model = parallel._DistributedDataParallelC10d( + module=model, + device_ids=[args.device_id], + output_device=args.device_id, + broadcast_buffers=False, + bucket_cap_mb=args.c10d_bucket_cap_mb, + ) + + def __call__(self, *args, **kwargs): + return self.ddp_model(*args, **kwargs) + + def forward(self, *args, **kwargs): + return self.ddp_model.forward(*args, **kwargs) + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + pass + try: + return self.ddp_model.__getattr__(name) + except AttributeError: + pass + return self.ddp_model.module.__getattr__(name) diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index 91934dce6..2424d7282 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -9,6 +9,7 @@ import importlib import os from .fairseq_optimizer import FairseqOptimizer +from .fp16_optimizer import FP16Optimizer OPTIMIZER_REGISTRY = {} @@ -16,7 +17,7 @@ OPTIMIZER_CLASS_NAMES = set() def build_optimizer(args, params): - params = filter(lambda p: p.requires_grad, params) + params = list(filter(lambda p: p.requires_grad, params)) return OPTIMIZER_REGISTRY[args.optimizer](args, params) diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index 91760582a..984956d45 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -5,7 +5,9 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. -import torch.optim +import math + +import torch class FairseqOptimizer(object): @@ -13,7 +15,7 @@ class FairseqOptimizer(object): def __init__(self, args, params): super().__init__() self.args = args - self.params = params + self.params = list(params) @staticmethod def add_args(parser): @@ -67,10 +69,25 @@ class FairseqOptimizer(object): for group in self.optimizer.param_groups: group.update(optimizer_overrides) + def backward(self, loss): + loss.backward() + + def multiply_grads(self, c): + """Multiplies grads by a constant ``c``.""" + for p in self.params: + p.grad.data.mul_(c) + + def clip_grad_norm(self, max_norm): + """Clips gradient norm.""" + if max_norm > 0: + return torch.nn.utils.clip_grad_norm_(self.params, max_norm) + else: + return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params)) + def step(self, closure=None): """Performs a single optimization step.""" - return self.optimizer.step(closure) + self.optimizer.step(closure) def zero_grad(self): """Clears the gradients of all optimized parameters.""" - return self.optimizer.zero_grad() + self.optimizer.zero_grad() diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py new file mode 100644 index 000000000..a7293d0cd --- /dev/null +++ b/fairseq/optim/fp16_optimizer.py @@ -0,0 +1,164 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
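    # Sketch of how the wrapper above behaves (the Trainer constructs it itself;
    # `args`, `model` and `sample` are placeholders, with `model` assumed to be a
    # standard FairseqModel exposing encoder/decoder submodules).
    from fairseq.models import DistributedFairseqModel

    wrapped = DistributedFairseqModel(args, model)
    net_output = wrapped(**sample['net_input'])  # __call__ runs DistributedDataParallel
    decoder = wrapped.decoder                    # unknown attributes fall through to the
                                                 # DDP instance and then to the raw model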
+ +import torch + +from fairseq import optim, utils + + +class DynamicLossScaler: + + def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000): + self.loss_scale = init_scale + self.scale_factor = scale_factor + self.scale_window = scale_window + self._iter = 0 + self._last_overflow_iter = -1 + + def update_scale(self, overflow): + if overflow: + self.loss_scale /= self.scale_factor + self._last_overflow_iter = self._iter + elif (self._iter - self._last_overflow_iter) % self.scale_window == 0: + self.loss_scale *= self.scale_factor + self._iter += 1 + + @staticmethod + def has_overflow(grad_norm): + # detect inf and nan + if grad_norm == float('inf') or grad_norm != grad_norm: + return True + return False + + +class FP16Optimizer(optim.FairseqOptimizer): + + def __init__(self, args, params, fp32_optimizer, fp32_params): + super().__init__(args, params) + self.fp32_optimizer = fp32_optimizer + self.fp32_params = fp32_params + self.scaler = DynamicLossScaler( + init_scale=2.**7, + scale_window=(2**14 / args.distributed_world_size), + ) + + @staticmethod + def build_optimizer(args, params): + # create FP32 copy of parameters and grads + total_param_size = sum(p.data.numel() for p in params) + fp32_params = params[0].new(0).float().new(total_param_size) + offset = 0 + for p in params: + numel = p.data.numel() + fp32_params[offset:offset+numel].copy_(p.data.view(-1)) + offset += numel + fp32_params = torch.nn.Parameter(fp32_params) + fp32_params.grad = fp32_params.data.new(total_param_size) + + fp32_optimizer = optim.build_optimizer(args, [fp32_params]) + return FP16Optimizer(args, params, fp32_optimizer, fp32_params) + + @property + def optimizer(self): + return self.fp32_optimizer.optimizer + + @property + def optimizer_config(self): + return self.fp32_optimizer.optimizer_config + + def get_lr(self): + return self.fp32_optimizer.get_lr() + + def set_lr(self, lr): + self.fp32_optimizer.set_lr(lr) + + def state_dict(self): + """Return the optimizer's state dict.""" + state_dict = self.fp32_optimizer.state_dict() + state_dict['loss_scale'] = self.scaler.loss_scale + return state_dict + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. 
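    # Standalone illustration of the DynamicLossScaler policy above: the scale is
    # halved on an overflow and doubled again after `scale_window` clean steps.
    scaler = DynamicLossScaler(init_scale=2.**7, scale_window=4)
    for grad_norm in [1.0, float('inf'), 1.0, 1.0, 1.0, 1.0]:
        scaler.update_scale(DynamicLossScaler.has_overflow(grad_norm))
    assert scaler.loss_scale == 2.**7  # dipped to 2.**6 after the inf, then recovered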
+ """ + if 'loss_scale' in state_dict: + self.scaler.loss_scale = state_dict['loss_scale'] + self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides) + + def backward(self, loss): + loss = loss * self.scaler.loss_scale + loss.backward() + self._needs_sync = True + + def _sync_fp16_grads_to_fp32(self, multiply_grads=1.): + if self._needs_sync: + # copy FP16 grads to FP32 + offset = 0 + for p in self.params: + if not p.requires_grad: + continue + numel = p.grad.data.numel() + self.fp32_params.grad.data[offset:offset+numel].copy_(p.grad.data.view(-1)) + offset += numel + + # correct for dynamic loss scaler + self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale) + + self._needs_sync = False + + def multiply_grads(self, c): + """Multiplies grads by a constant ``c``.""" + if self._needs_sync: + self._sync_fp16_grads_to_fp32(c) + else: + self.fp32_params.grad.data.mul_(c) + + def clip_grad_norm(self, max_norm): + """Clips gradient norm and updates dynamic loss scaler.""" + self._sync_fp16_grads_to_fp32() + grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm) + + # detect overflow and adjust loss scale + overflow = DynamicLossScaler.has_overflow(grad_norm) + self.scaler.update_scale(overflow) + if overflow: + if self.scaler.loss_scale <= self.args.min_loss_scale: + raise Exception(( + 'Minimum loss scale reached ({}). Your loss is probably exploding. ' + 'Try lowering the learning rate, using gradient clipping or ' + 'increasing the batch size.' + ).format(self.args.min_loss_scale)) + raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale)) + return grad_norm + + def step(self, closure=None): + """Performs a single optimization step.""" + self._sync_fp16_grads_to_fp32() + self.fp32_optimizer.step(closure) + + # copy FP32 params back into FP16 model + offset = 0 + for p in self.params: + if not p.requires_grad: + continue + numel = p.data.numel() + p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data)) + offset += numel + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + self.fp32_optimizer.zero_grad() + for p in self.params: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + self._needs_sync = False diff --git a/fairseq/options.py b/fairseq/options.py index 70fac7ca7..5ae694826 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -183,6 +183,10 @@ def add_distributed_training_args(parser): help='port number (not required if using --distributed-init-method)') group.add_argument('--device-id', default=0, type=int, help='which GPU to use (usually configured automatically)') + group.add_argument('--no-c10d', action='store_true', + help='don\'t use c10d distributed backend') + group.add_argument('--c10d-bucket-cap-mb', default=150, metavar='MB', + help='bucket size for c10d backend') return group diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 56ba63420..07c5efcfa 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -15,7 +15,7 @@ from itertools import chain import torch -from fairseq import distributed_utils, optim, utils +from fairseq import distributed_utils, models, optim, utils from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter from fairseq.optim import lr_scheduler @@ -23,22 +23,27 @@ from fairseq.optim import lr_scheduler class Trainer(object): """Main class for data parallel training. 
- This class supports data parallel training, where multiple workers each - have a full model replica and gradients are accumulated synchronously via - torch.distributed.all_reduce. + This class supports synchronous distributed data parallel training, + where multiple workers each have a full model replica and gradients + are accumulated across workers before each update. We use + :class:`~torch.nn.parallel.DistributedDataParallel` to handle + communication of the gradients across workers. """ - def __init__(self, args, task, model, criterion): + def __init__(self, args, task, model, criterion, dummy_batch): if not torch.cuda.is_available(): raise NotImplementedError('Training on CPU is not supported') self.args = args + self.task = task # copy model and criterion to current device - self.task = task - self.model = model.cuda() self.criterion = criterion.cuda() + if args.fp16: + self._model = model.half().cuda() + else: + self._model = model.cuda() # initialize meters self.meters = OrderedDict() @@ -53,14 +58,27 @@ class Trainer(object): self.meters['gnorm'] = AverageMeter() # gradient norm self.meters['clip'] = AverageMeter() # % of updates clipped self.meters['oom'] = AverageMeter() # out of memory + if args.fp16: + self.meters['loss_scale'] = AverageMeter() # dynamic loss scale self.meters['wall'] = TimeMeter() # wall time in seconds self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds - self._buffered_stats = defaultdict(lambda: []) - self._flat_grads = None + self._dummy_batch = dummy_batch self._num_updates = 0 self._optim_history = None self._optimizer = None + self._wrapped_model = None + + @property + def model(self): + if self._wrapped_model is None: + if self.args.distributed_world_size > 1: + self._wrapped_model = models.DistributedFairseqModel( + self.args, self._model, + ) + else: + self._wrapped_model = self._model + return self._wrapped_model @property def optimizer(self): @@ -69,7 +87,17 @@ class Trainer(object): return self._optimizer def _build_optimizer(self): - self._optimizer = optim.build_optimizer(self.args, self.model.parameters()) + if self.args.fp16: + if torch.cuda.get_device_capability(0)[0] < 7: + print('| WARNING: your device does NOT support faster training with --fp16, ' + 'please switch to FP32 which is likely to be faster') + params = list(filter(lambda p: p.requires_grad, self.model.parameters())) + self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params) + else: + if torch.cuda.get_device_capability(0)[0] >= 7: + print('| NOTICE: your device may support faster training with --fp16') + self._optimizer = optim.build_optimizer(self.args, self.model.parameters()) + self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer) def save_checkpoint(self, filename, extra_state): @@ -77,31 +105,27 @@ class Trainer(object): if distributed_utils.is_master(self.args): # only save one checkpoint extra_state['train_meters'] = self.meters utils.save_state( - filename, self.args, self.model, self.criterion, self.optimizer, + filename, self.args, self.get_model(), self.criterion, self.optimizer, self.lr_scheduler, self._num_updates, self._optim_history, extra_state, ) def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None): """Load all training state from a checkpoint file.""" extra_state, self._optim_history, last_optim_state = \ - utils.load_model_state(filename, self.model) - + utils.load_model_state(filename, self.get_model()) if last_optim_state is not None 
and not reset_optimizer: # rebuild optimizer after loading model, since params may have changed self._build_optimizer() # only reload optimizer and lr_scheduler if they match last_optim = self._optim_history[-1] - assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \ 'criterion does not match; please reset the optimizer (--reset-optimizer)' - assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \ 'optimizer does not match; please reset the optimizer (--reset-optimizer)' if not reset_lr_scheduler: self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state']) - self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) self._num_updates = last_optim['num_updates'] @@ -117,7 +141,7 @@ class Trainer(object): return extra_state - def train_step(self, sample, update_params=True, dummy_batch=False): + def train_step(self, samples, dummy_batch=False): """Do forward, backward and parameter update.""" # Set seed based on args.seed and the update number so that we get # reproducible results when resuming from checkpoints @@ -125,230 +149,164 @@ class Trainer(object): torch.manual_seed(seed) torch.cuda.manual_seed(seed) + self.model.train() + self.zero_grad() + if not dummy_batch: self.meters['train_wall'].start() # forward and backward pass - sample = self._prepare_sample(sample) - loss, sample_size, logging_output, oom_fwd = self._forward(sample) - oom_bwd = self._backward(loss) + logging_outputs, sample_sizes, ooms = [], [], 0 + for i, sample in enumerate(samples): + sample = self._prepare_sample(sample) + if sample is None: + # when sample is None, run forward/backward on a dummy batch + # and ignore the resulting gradients + sample = self._prepare_sample(self._dummy_batch) + ignore_grad = True + else: + ignore_grad = False - # buffer stats and logging outputs - self._buffered_stats['sample_sizes'].append(sample_size) - self._buffered_stats['logging_outputs'].append(logging_output) - self._buffered_stats['ooms_fwd'].append(oom_fwd) - self._buffered_stats['ooms_bwd'].append(oom_bwd) + try: + # forward + loss, sample_size, logging_output = self.task.get_loss( + self.model, self.criterion, sample, + ) + if ignore_grad: + loss *= 0 - # update parameters - if update_params: - agg_logging_output = self._update_params() - else: - agg_logging_output = None # buffering updates + if self.args.distributed_world_size > 1: + # only all-reduce gradients in the last backwards pass + if i < len(samples) - 1: + self.model.need_reduction = False + else: + self.model.need_reduction = True - if not dummy_batch: - self.meters['train_wall'].stop() + # backward + self.optimizer.backward(loss) - return agg_logging_output + if not ignore_grad: + logging_outputs.append(logging_output) + sample_sizes.append(sample_size) + except RuntimeError as e: + if 'out of memory' in str(e): + print('| WARNING: ran out of memory, skipping batch') + ooms += 1 + self.zero_grad() + else: + raise e + + if dummy_batch: + return None - def _update_params(self): # gather logging outputs from all replicas - sample_sizes = self._buffered_stats['sample_sizes'] - logging_outputs = self._buffered_stats['logging_outputs'] - ooms_fwd = self._buffered_stats['ooms_fwd'] - ooms_bwd = self._buffered_stats['ooms_bwd'] if self.args.distributed_world_size > 1: - sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map( - lambda l: list(chain.from_iterable(l)), - zip(*distributed_utils.all_gather_list( - (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd) - )) - ) - ooms_fwd = sum(ooms_fwd) - 
ooms_bwd = sum(ooms_bwd) + logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list( + [logging_outputs, sample_sizes, ooms], + )) + logging_outputs = list(chain.from_iterable(logging_outputs)) + sample_sizes = list(chain.from_iterable(sample_sizes)) + ooms = sum(ooms) - if ooms_fwd == self.args.distributed_world_size: - print('| WARNING: OOM in all workers, skipping batch') + if ooms == self.args.distributed_world_size: + print('| WARNING: OOM in all workers, skipping update') self.zero_grad() return None - # aggregate stats and logging outputs - ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) - nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) - agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs) - grad_denom = self.criterion.__class__.grad_denom(sample_sizes) + # aggregate logging outputs and sample sizes + logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs) + sample_size = self.criterion.__class__.grad_denom(sample_sizes) + + if not all(k in logging_output for k in ['ntokens', 'nsentences']): + raise Exception(( + 'Please update the {}.aggregate_logging_outputs() method to ' + 'return ntokens and nsentences' + ).format(self.criterion.__class__.__name__)) try: - # all-reduce and rescale gradients, then take an optimization step - grad_norm = self._all_reduce_and_rescale(grad_denom) - self._opt() + # normalize grads by sample size + self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size)) + + # clip grads + grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm) + + # take an optimization step + self.optimizer.step() + self._num_updates += 1 + + # update learning rate + self.lr_scheduler.step_update(self._num_updates) # update meters + ntokens = logging_output.get('ntokens', 0) + nsentences = logging_output.get('nsentences', 0) self.meters['wps'].update(ntokens) self.meters['ups'].update(1.) self.meters['wpb'].update(ntokens) self.meters['bsz'].update(nsentences) - if grad_norm is not None: - self.meters['gnorm'].update(grad_norm) - self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.) - self.meters['oom'].update(ooms_fwd + ooms_bwd) - - # update loss meters for training - if 'loss' in agg_logging_output: - self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom) - # criterions can optionally log the NLL loss too - if 'nll_loss' in agg_logging_output: - self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens) + self.meters['gnorm'].update(grad_norm) + self.meters['clip'].update( + 1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0. 
+ ) + self.meters['oom'].update(ooms) + self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size) + self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens) except OverflowError as e: - self.zero_grad() print('| WARNING: overflow detected, ' + str(e)) + self.zero_grad() + logging_output = None - self.clear_buffered_stats() + if self.args.fp16: + self.meters['loss_scale'].reset() + self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale) - return agg_logging_output + self.meters['train_wall'].stop() - def _forward(self, sample, eval=False): - loss = None - sample_size = 0 - logging_output = { - 'ntokens': sample['ntokens'] if sample is not None else 0, - 'nsentences': sample['target'].size(0) if sample is not None else 0, - } - oom = 0 - try: - # prepare model and optimizer - if eval: - self.model.eval() - else: - self.model.train() - - if sample is not None: - with torch.no_grad() if eval else contextlib.ExitStack(): - # calculate loss and sample size - loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample) - logging_output.update(logging_output_) - except RuntimeError as e: - if not eval and 'out of memory' in str(e): - print('| WARNING: ran out of memory, skipping batch') - oom = 1 - loss = None - else: - raise e - return loss, sample_size, logging_output, oom - - def _backward(self, loss): - oom = 0 - if loss is not None: - try: - # backward pass - loss.backward() - except RuntimeError as e: - if 'out of memory' in str(e): - print('| WARNING: ran out of memory, skipping batch') - oom = 1 - self.zero_grad() - else: - raise e - return oom - - def _all_reduce_and_rescale(self, grad_denom): - # flatten grads into a single buffer and all-reduce - flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads) - if self.args.distributed_world_size > 1: - torch.distributed.all_reduce(flat_grads) - - # rescale and clip gradients - flat_grads.div_(grad_denom) - grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm) - - # copy grads back into model parameters - self._set_flat_grads(flat_grads) - - return grad_norm - - def _get_grads(self): - grads = [] - for name, p in self.model.named_parameters(): - if not p.requires_grad: - continue - if p.grad is None: - print('WARNING: model parameter did not receive gradient: ' + name + '. 
' - 'Check that you\'re using the param in the forward pass or set requires_grad=False') - grads.append(p.new_zeros(p.shape)) - else: - grads.append(p.grad.data) - return grads - - def _get_flat_grads(self, out=None): - grads = self._get_grads() - if out is None: - grads_size = sum(g.numel() for g in grads) - out = grads[0].new(grads_size).zero_() - offset = 0 - for g in grads: - numel = g.numel() - out[offset:offset+numel].copy_(g.view(-1)) - offset += numel - return out[:offset] - - def _set_flat_grads(self, new_grads): - grads = self._get_grads() - offset = 0 - for g in grads: - numel = g.numel() - g.copy_(new_grads[offset:offset+numel].view_as(g)) - offset += numel - - def _opt(self): - # take an optimization step - self.optimizer.step() - self.zero_grad() - self._num_updates += 1 - - # update learning rate - self.lr_scheduler.step_update(self._num_updates) + return logging_output def valid_step(self, sample): """Do forward pass in evaluation mode.""" - # forward pass - sample = self._prepare_sample(sample) - _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True) - assert not oom_fwd, 'Ran out of memory during validation' + self.model.eval() - # gather logging outputs from all GPUs + logging_output, sample_size = {}, 0 + with torch.no_grad(): + sample = self._prepare_sample(sample) + if sample is None: + sample = self._prepare_sample(self._dummy_batch) + _loss, sample_size, logging_output = self.task.get_loss( + self.model, self.criterion, sample, + ) + + # gather logging outputs from all replicas if self.args.distributed_world_size > 1: - sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list( - (sample_size, logging_output) + logging_output, sample_size = zip(*distributed_utils.all_gather_list( + [logging_output, sample_size], )) + logging_output = list(logging_output) + sample_size = list(sample_size) else: - sample_sizes = [sample_size] - logging_outputs = [logging_output] + logging_output = [logging_output] + sample_size = [sample_size] - # aggregate stats and logging outputs - ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) - grad_denom = self.criterion.__class__.grad_denom(sample_sizes) - agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs) + # aggregate logging outputs and sample sizes + logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_output) + sample_size = self.criterion.__class__.grad_denom(sample_size) - # update loss meters for validation - if 'loss' in agg_logging_output: - self.meters['valid_loss'].update(agg_logging_output['loss'], grad_denom) - # criterions can optionally log the NLL loss too - if 'nll_loss' in agg_logging_output: - self.meters['valid_nll_loss'].update(agg_logging_output['nll_loss'], ntokens) + # update meters for validation + ntokens = logging_output.get('ntokens', 0) + self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size) + self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens) - return agg_logging_output + return logging_output def dummy_train_step(self, dummy_batch): """Dummy training step for warming caching allocator.""" - self.train_step(dummy_batch, update_params=False, dummy_batch=True) + self.train_step(dummy_batch, dummy_batch=True) self.zero_grad() - self.clear_buffered_stats() def zero_grad(self): self.optimizer.zero_grad() - def clear_buffered_stats(self): - self._buffered_stats.clear() - def lr_step(self, epoch, val_loss=None): """Adjust the learning rate based on the 
validation loss."""
         return self.lr_scheduler.step(epoch, val_loss)
@@ -362,8 +320,8 @@ class Trainer(object):
         return self.optimizer.get_lr()
 
     def get_model(self):
-        """Get the model replica."""
-        return self.model
+        """Get the (non-wrapped) model instance."""
+        return self._model
 
     def get_meter(self, name):
         """Get a specific meter by name."""
diff --git a/multiprocessing_train.py b/multiprocessing_train.py
index 062067adf..80c46706d 100644
--- a/multiprocessing_train.py
+++ b/multiprocessing_train.py
@@ -19,8 +19,10 @@ from train import main as single_process_main
 
 def main(args):
     # Set distributed training parameters for a single node.
     args.distributed_world_size = torch.cuda.device_count()
-    args.distributed_init_method = 'tcp://localhost:{port}'.format(
-        port=random.randint(10000, 20000))
+    port = random.randint(10000, 20000)
+    args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
+    args.distributed_init_host = 'localhost'
+    args.distributed_port = port + 1
 
     mp = torch.multiprocessing.get_context('spawn')
diff --git a/setup.py b/setup.py
index 97c00f135..53e8eaf51 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@ bleu = Extension(
 
 setup(
     name='fairseq',
-    version='0.5.0',
+    version='0.6.0',
     description='Facebook AI Research Sequence-to-Sequence Toolkit',
     long_description=readme,
     license=license,
diff --git a/train.py b/train.py
index e1d76bfe6..34b2a3e3d 100644
--- a/train.py
+++ b/train.py
@@ -16,7 +16,7 @@ import math
 import torch
 
 from fairseq import distributed_utils, options, progress_bar, tasks, utils
-from fairseq.fp16_trainer import FP16Trainer
+from fairseq.data import iterators
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
 
@@ -43,16 +43,17 @@ def main(args):
     print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
     print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))
 
+    # Make a dummy batch to (i) warm the caching allocator and (ii) as a
+    # placeholder DistributedDataParallel when there's an uneven number of
+    # batches per worker.
+    max_positions = utils.resolve_max_positions(
+        task.max_positions(),
+        model.max_positions(),
+    )
+    dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
+
     # Build trainer
-    if args.fp16:
-        if torch.cuda.get_device_capability(0)[0] < 7:
-            print('| WARNING: your device does NOT support faster training with --fp16,'
-                  ' please switch to FP32 which is likely to be faster')
-        trainer = FP16Trainer(args, task, model, criterion)
-    else:
-        if torch.cuda.get_device_capability(0)[0] >= 7:
-            print('| NOTICE: your device may support faster training with --fp16')
-        trainer = Trainer(args, task, model, criterion)
+    trainer = Trainer(args, task, model, criterion, dummy_batch)
     print('| training on {} GPUs'.format(args.distributed_world_size))
     print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
         args.max_tokens,
@@ -60,10 +61,6 @@
     ))
 
     # Initialize dataloader
-    max_positions = utils.resolve_max_positions(
-        task.max_positions(),
-        trainer.get_model().max_positions(),
-    )
     epoch_itr = task.get_batch_iterator(
         dataset=task.dataset(args.train_subset),
         max_tokens=args.max_tokens,
@@ -78,9 +75,7 @@
 
     # Load the latest checkpoint if one is available
     if not load_checkpoint(args, trainer, epoch_itr):
-        # Send a dummy batch to warm the caching allocator
-        dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
-        trainer.dummy_train_step(dummy_batch)
+        trainer.dummy_train_step([dummy_batch])
 
     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
@@ -110,32 +105,32 @@
 def train(args, trainer, task, epoch_itr):
     """Train the model for one epoch."""
 
-    # Initialize data iterator
-    itr = epoch_itr.next_epoch_itr()
-    progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')
-
-    # update parameters every N batches
+    # Update parameters every N batches
     if epoch_itr.epoch <= len(args.update_freq):
         update_freq = args.update_freq[epoch_itr.epoch - 1]
     else:
         update_freq = args.update_freq[-1]
 
+    # Initialize data iterator
+    itr = epoch_itr.next_epoch_itr()
+    itr = iterators.GroupedIterator(itr, update_freq)
+    progress = progress_bar.build_progress_bar(
+        args, itr, epoch_itr.epoch, no_progress_bar='simple',
+    )
+
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     first_valid = args.valid_subset.split(',')[0]
     max_update = args.max_update or math.inf
     num_batches = len(epoch_itr)
-    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
-        if i < num_batches - 1 and (i + 1) % update_freq > 0:
-            # buffer updates according to --update-freq
-            trainer.train_step(sample, update_params=False)
+    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
+        log_output = trainer.train_step(samples)
+        if log_output is None:
             continue
-        else:
-            log_output = trainer.train_step(sample, update_params=True)
 
         # log mid-epoch stats
         stats = get_training_stats(trainer)
         for k, v in log_output.items():
-            if k in ['loss', 'nll_loss', 'sample_size']:
+            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                 continue  # these are already logged above
             if 'loss' in k:
                 extra_meters[k].update(v, log_output['sample_size'])
@@ -163,7 +158,9 @@
     progress.print(stats)
 
     # reset training meters
-    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip', 'gnorm']:
+    for k in [
+        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
+    ]:
         meter = trainer.get_meter(k)
         if meter is not None:
             meter.reset()
@@ -230,7 +227,7 @@
             log_output = trainer.valid_step(sample)
 
             for k, v in log_output.items():
-                if k in ['loss', 'nll_loss', 'sample_size']:
+                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                     continue
                 extra_meters[k].update(v)
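For reference, the per-epoch flow that the train.py changes above add up to, in a stripped-down sketch (names follow the patch; dataset/task setup and logging are elided):

    # one optimizer step per group of `update_freq` batches
    itr = epoch_itr.next_epoch_itr()
    itr = iterators.GroupedIterator(itr, update_freq)
    for samples in itr:
        # Trainer.train_step() runs forward/backward for every sample in the group,
        # all-reduces gradients only on the last backward pass, then takes a single
        # step; it returns None if every worker hit OOM or FP16 overflowed.
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue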