Switch to DistributedDataParallelC10d and bump version 0.5.0 -> 0.6.0

- no more FP16Trainer; FP16 training is now handled by an FP16Optimizer wrapper
- most of the distributed code moves into a new wrapper class, DistributedFairseqModel, which behaves like DistributedDataParallel and a FairseqModel at the same time
- Trainer now requires an extra dummy_batch argument at initialization. We run forward/backward on the dummy batch whenever workers have an uneven number of batches, and hide the resulting gradients by multiplying the loss by 0 (sketched just below the commit summary)
- Trainer.train_step now takes a list of samples, which allows a cleaner --update-freq implementation
This commit is contained in:
Sergey Edunov 2018-09-06 11:46:54 -07:00 committed by Myle Ott
parent 311d2c6ca9
commit 1082ba352c
20 changed files with 589 additions and 426 deletions
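The dummy-batch trick is easy to see in isolation: because the loss is multiplied by 0, backward still runs (keeping distributed gradient synchronization in lockstep across workers) but every gradient it produces is exactly zero. A minimal standalone sketch, with a toy model and placeholder batches rather than real fairseq objects:

import torch
import torch.nn as nn

model = nn.Linear(8, 2)
criterion = nn.CrossEntropyLoss()

def do_step(sample, dummy_batch):
    # sample is None on workers that have run out of real batches
    ignore_grad = sample is None
    inputs, targets = dummy_batch if ignore_grad else sample
    loss = criterion(model(inputs), targets)
    if ignore_grad:
        loss = loss * 0.  # backward still runs, but all gradients are zero
    loss.backward()
    return None if ignore_grad else loss.item()

dummy = (torch.randn(1, 8), torch.tensor([0]))
do_step((torch.randn(4, 8), torch.randint(0, 2, (4,))), dummy)  # real batch
do_step(None, dummy)  # dummy batch: parameters see only zero gradients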

View File

@ -30,7 +30,7 @@ def main(args):
raise e
except FileNotFoundError as e: # Slurm is not installed
pass
if args.distributed_init_method is None:
if args.distributed_init_method is None and args.distributed_port is None:
raise ValueError('--distributed-init-method or --distributed-port '
'must be specified for distributed training')

View File

@ -60,9 +60,9 @@ github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/'
# built documents.
#
# The short X.Y version.
version = '0.5.0'
version = '0.6.0'
# The full version, including alpha/beta/rc tags.
release = '0.5.0'
release = '0.6.0'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.

View File

@ -36,5 +36,7 @@ Iterators
:members:
.. autoclass:: fairseq.data.EpochBatchIterator
:members:
.. autoclass:: fairseq.data.GroupedIterator
:members:
.. autoclass:: fairseq.data.ShardedIterator
:members:

View File

@ -54,6 +54,7 @@ class AdaptiveLoss(FairseqCriterion):
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
return loss, sample_size, logging_output
@ -63,9 +64,12 @@ class AdaptiveLoss(FairseqCriterion):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
if sample_size != ntokens:

View File

@ -37,6 +37,7 @@ class CrossEntropyCriterion(FairseqCriterion):
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
return loss, sample_size, logging_output
@ -46,9 +47,12 @@ class CrossEntropyCriterion(FairseqCriterion):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
if sample_size != ntokens:

View File

@ -40,6 +40,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
return loss, sample_size, logging_output
@ -58,14 +59,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
return loss, nll_loss
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
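As a concrete illustration of the aggregation above (toy numbers, two workers' logging outputs):

import math

logs = [
    {'loss': 8.0, 'nll_loss': 6.0, 'ntokens': 10, 'nsentences': 2, 'sample_size': 10},
    {'loss': 4.0, 'nll_loss': 3.0, 'ntokens': 5, 'nsentences': 1, 'sample_size': 5},
]
sample_size = sum(log['sample_size'] for log in logs)  # 15
ntokens = sum(log['ntokens'] for log in logs)          # 15
loss = sum(log['loss'] for log in logs) / sample_size / math.log(2)      # ~1.154 bits per token
nll_loss = sum(log['nll_loss'] for log in logs) / ntokens / math.log(2)  # ~0.866 bits per token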

View File

@ -12,18 +12,24 @@ from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset
from .iterators import CountingIterator, EpochBatchIterator, ShardedIterator
from .iterators import (
CountingIterator,
EpochBatchIterator,
GroupedIterator,
ShardedIterator,
)
__all__ = [
'CountingIterator',
'Dictionary',
'EpochBatchIterator',
'FairseqDataset',
'GroupedIterator',
'IndexedDataset',
'IndexedInMemoryDataset',
'IndexedRawTextDataset',
'LanguagePairDataset',
'MonolingualDataset',
'TokenBlockDataset',
'ShardedIterator',
'TokenBlockDataset',
]

View File

@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
import itertools
import math
import numpy as np
import torch
@ -150,6 +151,36 @@ class EpochBatchIterator(object):
))
class GroupedIterator(object):
"""Wrapper around an iterable that returns groups (chunks) of items.
Args:
iterable (iterable): iterable to wrap
chunk_size (int): size of each chunk
"""
def __init__(self, iterable, chunk_size):
self._len = int(math.ceil(len(iterable) / float(chunk_size)))
self.itr = iter(iterable)
self.chunk_size = chunk_size
def __len__(self):
return self._len
def __iter__(self):
return self
def __next__(self):
chunk = []
try:
for _ in range(self.chunk_size):
chunk.append(next(self.itr))
except StopIteration as e:
if len(chunk) == 0:
raise e
return chunk
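For illustration (toy values, not part of the patch), GroupedIterator simply chunks its input; train() below wraps the epoch iterator with chunk_size equal to the epoch's update frequency, so with args.update_freq = [4, 2, 1] epoch 1 uses chunks of 4, epoch 2 chunks of 2, and later epochs chunks of 1:

itr = GroupedIterator(range(7), chunk_size=3)
print(len(itr))   # 3, i.e. ceil(7 / 3)
print(list(itr))  # [[0, 1, 2], [3, 4, 5], [6]]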
class ShardedIterator(object):
"""A sharded wrapper around an iterable, padded to length.

View File

@ -7,7 +7,9 @@
import pickle
import torch.distributed
import torch
from torch import distributed
from torch.distributed import group
from fairseq import utils
@ -16,22 +18,39 @@ def is_master(args):
return args.distributed_rank == 0
_use_c10d = [None]
def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')
if _use_c10d[0] is None:
_use_c10d[0] = not args.no_c10d
if _use_c10d[0] and not hasattr(torch.nn.parallel, '_DistributedDataParallelC10d'):
_use_c10d[0] = False
print('WARNING: cannot find DistributedDataParallelC10d, '
'falling back to standard DistributedDataParallel')
print('| distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method), flush=True)
if args.distributed_init_method.startswith('tcp://'):
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method,
world_size=args.distributed_world_size, rank=args.distributed_rank)
else:
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method,
world_size=args.distributed_world_size)
args.distributed_rank = torch.distributed.get_rank()
if _use_c10d[0]:
distributed.c10d.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
else:
distributed.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
if not is_master(args):
suppress_output()
@ -52,35 +71,77 @@ def suppress_output():
__builtin__.print = print
def all_gather_list(data, max_size=16384):
"""Gathers arbitrary data from all nodes into a list."""
world_size = torch.distributed.get_world_size()
if not hasattr(all_gather_list, '_in_buffer') or \
max_size != all_gather_list._in_buffer.size():
all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
all_gather_list._out_buffers = [
torch.cuda.ByteTensor(max_size)
for i in range(world_size)
]
in_buffer = all_gather_list._in_buffer
out_buffers = all_gather_list._out_buffers
def get_rank():
if _use_c10d[0]:
return distributed.c10d.get_rank()
else:
return distributed.get_rank()
def get_world_size():
if _use_c10d[0]:
return distributed.c10d.get_world_size()
else:
return distributed.get_world_size()
def get_default_group():
if _use_c10d[0]:
return distributed.c10d.group.WORLD
else:
return distributed.group.WORLD
def all_reduce(tensor, group=None):
if group is None:
group = get_default_group()
if _use_c10d[0]:
return distributed.c10d.all_reduce(tensor, group=group)
else:
return distributed.all_reduce(tensor, group=group)
def all_gather_list(data, group=None, max_size=16384):
"""Gathers arbitrary data from all nodes into a list.
Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
data. Note that *data* must be picklable.
Args:
data (Any): data from the local worker to be gathered on other workers
group (optional): group of the collective
max_size (int, optional): maximum size of the data to be gathered
across workers
"""
rank = get_rank()
world_size = get_world_size()
buffer_size = max_size * world_size
if not hasattr(all_gather_list, '_buffer') or \
all_gather_list._buffer.numel() < buffer_size:
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
buffer = all_gather_list._buffer
buffer.zero_()
enc = pickle.dumps(data)
enc_size = len(enc)
if enc_size + 2 > max_size:
raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
assert max_size < 255*256
in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
in_buffer[1] = enc_size % 255
in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))
torch.distributed.all_gather(out_buffers, in_buffer.cuda())
buffer_rank = buffer[rank * max_size : (rank + 1) * max_size]
buffer_rank[0] = enc_size // 255 # this encoding works for max_size < 65k
buffer_rank[1] = enc_size % 255
buffer_rank[2:enc_size+2] = torch.ByteTensor(list(enc))
all_reduce(buffer, group=group)
result = []
for i in range(world_size):
out_buffer = out_buffers[i]
out_buffer = buffer[i * max_size : (i + 1) * max_size]
size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
result.append(
pickle.loads(bytes(out_buffer[2:size+2].tolist()))
)
if size > 0:
result.append(
pickle.loads(bytes(out_buffer[2:size+2].tolist()))
)
return result
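The new all_gather_list gives every worker its own max_size slice of a zeroed buffer and then all-reduces the whole buffer, which acts like a gather because each slice is non-zero on exactly one rank. A single-process sketch of just the length-prefix encoding and decoding (the real function runs on CUDA byte tensors across workers):

import pickle
import torch

max_size, world_size, rank = 64, 2, 0
buffer = torch.zeros(max_size * world_size, dtype=torch.uint8)

enc = pickle.dumps({'ntokens': 123})
assert len(enc) + 2 <= max_size
slot = buffer[rank * max_size:(rank + 1) * max_size]
slot[0], slot[1] = len(enc) // 255, len(enc) % 255  # two-byte length prefix
slot[2:len(enc) + 2] = torch.tensor(list(enc), dtype=torch.uint8)

# after all_reduce(buffer), every worker decodes each slice the same way:
size = 255 * int(slot[0]) + int(slot[1])
print(pickle.loads(bytes(slot[2:size + 2].tolist())))  # {'ntokens': 123}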

View File

@ -1,154 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a network on multiple GPUs.
"""
import torch
from fairseq import optim, utils
from fairseq.meters import AverageMeter
from fairseq.optim import lr_scheduler
from fairseq.trainer import Trainer
class DynamicLossScaler:
def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
self.loss_scale = init_scale
self.scale_factor = scale_factor
self.scale_window = scale_window
self._iter = 0
self._last_overflow_iter = -1
def update_scale(self, overflow):
if overflow:
self.loss_scale /= self.scale_factor
self._last_overflow_iter = self._iter
elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
self.loss_scale *= self.scale_factor
self._iter += 1
@staticmethod
def has_overflow(grad_norm):
# detect inf and nan
if grad_norm == float('inf') or grad_norm != grad_norm:
return True
return False
class FP16Trainer(Trainer):
"""Modified trainer for FP16.
We maintain two copies of the model's parameters, both in FP16 and FP32.
We do forward/backward with FP16 and compute the loss + optimize with FP32.
"""
def __init__(self, args, task, model, criterion):
super().__init__(args, task, model, criterion)
# convert model to FP16 (but keep criterion FP32)
self.model.half()
# dynamically scale loss to reduce overflow
self.scaler = DynamicLossScaler(init_scale=2.**7)
self.meters['loss_scale'] = AverageMeter()
def _build_optimizer(self):
# create FP32 copy of parameters and grads
params = [p for p in self.model.parameters() if p.requires_grad]
total_param_size = sum(p.data.numel() for p in params)
self.fp32_params = params[0].new(0).float().new(total_param_size)
offset = 0
for p in params:
numel = p.data.numel()
self.fp32_params[offset:offset+numel].copy_(p.data.view(-1))
offset += numel
self.fp32_params = torch.nn.Parameter(self.fp32_params)
self.fp32_params.grad = self.fp32_params.data.new(total_param_size)
# create optimizer using the copied FP32 params
self._optimizer = optim.build_optimizer(self.args, [self.fp32_params])
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
extra_state['loss_scale'] = self.scaler.loss_scale
super().save_checkpoint(filename, extra_state)
def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
"""Load all training state from a checkpoint file."""
extra_state = super().load_checkpoint(filename, reset_optimizer, reset_lr_scheduler, optimizer_overrides)
if extra_state is not None and 'loss_scale' in extra_state:
self.scaler.loss_scale = extra_state['loss_scale']
return extra_state
def zero_grad(self):
# zero both the FP16 and FP32 grads
self.model.zero_grad() # FP16
self.optimizer.zero_grad() # FP32
def _backward(self, loss):
self.meters['loss_scale'].reset()
self.meters['loss_scale'].update(self.scaler.loss_scale)
if loss is not None:
# dynamically rescale loss to stay in FP16 range
loss = loss * self.scaler.loss_scale
return super()._backward(loss)
def _all_reduce_and_rescale(self, grad_denom):
# undo effect of dynamic loss scaling on gradients
grad_denom *= self.scaler.loss_scale
if self.args.distributed_world_size > 1:
# flatten grads into a single buffer
flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
# scale gradients to avoid overflow in all-reduce
flat_grads.div_(self.args.distributed_world_size)
grad_denom /= self.args.distributed_world_size
# all-reduce flat grads
torch.distributed.all_reduce(flat_grads)
# copy grads back to FP32
self.fp32_params.grad.data.copy_(flat_grads)
else:
# single worker: copy grads directly to FP32
self._get_flat_grads(out=self.fp32_params.grad.data)
# rescale and clip grads
self.fp32_params.grad.data.div_(grad_denom)
grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)
# detect overflow and adjust loss scale
overflow = DynamicLossScaler.has_overflow(grad_norm)
self.scaler.update_scale(overflow)
if overflow:
if self.scaler.loss_scale <= self.args.min_loss_scale:
raise Exception((
'Minimum loss scale reached ({}). Your loss is probably exploding. '
'Try lowering the learning rate, using gradient clipping or '
'increasing the batch size.'
).format(self.args.min_loss_scale))
raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
return grad_norm
def _opt(self):
# take an optimization step using the FP32 params and grads
super()._opt()
# copy FP32 params back into FP16 model
offset = 0
for p in self.model.parameters():
if not p.requires_grad:
continue
numel = p.data.numel()
p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
offset += numel

View File

@ -15,6 +15,7 @@ from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel # noqa: F401
from .composite_encoder import CompositeEncoder # noqa: F401
from .distributed_fairseq_model import DistributedFairseqModel # noqa: F401
MODEL_REGISTRY = {}

View File

@ -0,0 +1,62 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from torch.distributed import c10d
from torch.nn import parallel
from . import BaseFairseqModel
class DistributedFairseqModel(BaseFairseqModel):
"""
A wrapper around a :class:`BaseFairseqModel` instance that adds support for
distributed training.
Anytime a method or attribute is called on this class we first try to
forward it to the underlying DistributedDataParallel instance, otherwise we
forward it to the original :class:`BaseFairseqModel` instance.
Args:
args (argparse.Namespace): fairseq args
model (BaseFairseqModel): model to wrap
"""
def __init__(self, args, model):
super().__init__()
assert isinstance(model, BaseFairseqModel)
if args.no_c10d:
self.ddp_model = parallel.DistributedDataParallel(
module=model,
device_ids=[args.device_id],
output_device=args.device_id,
broadcast_buffers=False,
)
else:
self.ddp_model = parallel._DistributedDataParallelC10d(
module=model,
device_ids=[args.device_id],
output_device=args.device_id,
broadcast_buffers=False,
bucket_cap_mb=args.c10d_bucket_cap_mb,
)
def __call__(self, *args, **kwargs):
return self.ddp_model(*args, **kwargs)
def forward(self, *args, **kwargs):
return self.ddp_model.forward(*args, **kwargs)
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
pass
try:
return self.ddp_model.__getattr__(name)
except AttributeError:
pass
return self.ddp_model.module.__getattr__(name)
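The delegation in __getattr__ is the part worth internalizing: lookups try the wrapper first, then the DDP instance, then the wrapped FairseqModel. A dependency-free sketch of the same pattern (ToyWrapper and ToyModel are illustrative stand-ins, with a single fall-through level):

import torch.nn as nn

class ToyWrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.inner = module  # stands in for the DDP-wrapped model

    def forward(self, *args, **kwargs):
        return self.inner(*args, **kwargs)

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)   # nn.Module's own lookup
        except AttributeError:
            return getattr(self.inner, name)   # fall through to the wrapped module

class ToyModel(nn.Module):
    def max_positions(self):
        return 1024

wrapped = ToyWrapper(ToyModel())
print(wrapped.max_positions())  # 1024, resolved on the wrapped model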

View File

@ -9,6 +9,7 @@ import importlib
import os
from .fairseq_optimizer import FairseqOptimizer
from .fp16_optimizer import FP16Optimizer
OPTIMIZER_REGISTRY = {}
@ -16,7 +17,7 @@ OPTIMIZER_CLASS_NAMES = set()
def build_optimizer(args, params):
params = filter(lambda p: p.requires_grad, params)
params = list(filter(lambda p: p.requires_grad, params))
return OPTIMIZER_REGISTRY[args.optimizer](args, params)

View File

@ -5,7 +5,9 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim
import math
import torch
class FairseqOptimizer(object):
@ -13,7 +15,7 @@ class FairseqOptimizer(object):
def __init__(self, args, params):
super().__init__()
self.args = args
self.params = params
self.params = list(params)
@staticmethod
def add_args(parser):
@ -67,10 +69,25 @@ class FairseqOptimizer(object):
for group in self.optimizer.param_groups:
group.update(optimizer_overrides)
def backward(self, loss):
loss.backward()
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
for p in self.params:
p.grad.data.mul_(c)
def clip_grad_norm(self, max_norm):
"""Clips gradient norm."""
if max_norm > 0:
return torch.nn.utils.clip_grad_norm_(self.params, max_norm)
else:
return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params))
def step(self, closure=None):
"""Performs a single optimization step."""
return self.optimizer.step(closure)
self.optimizer.step(closure)
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
return self.optimizer.zero_grad()
self.optimizer.zero_grad()

View File

@ -0,0 +1,164 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from fairseq import optim, utils
class DynamicLossScaler:
def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
self.loss_scale = init_scale
self.scale_factor = scale_factor
self.scale_window = scale_window
self._iter = 0
self._last_overflow_iter = -1
def update_scale(self, overflow):
if overflow:
self.loss_scale /= self.scale_factor
self._last_overflow_iter = self._iter
elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
self.loss_scale *= self.scale_factor
self._iter += 1
@staticmethod
def has_overflow(grad_norm):
# detect inf and nan
if grad_norm == float('inf') or grad_norm != grad_norm:
return True
return False
class FP16Optimizer(optim.FairseqOptimizer):
def __init__(self, args, params, fp32_optimizer, fp32_params):
super().__init__(args, params)
self.fp32_optimizer = fp32_optimizer
self.fp32_params = fp32_params
self.scaler = DynamicLossScaler(
init_scale=2.**7,
scale_window=(2**14 / args.distributed_world_size),
)
@staticmethod
def build_optimizer(args, params):
# create FP32 copy of parameters and grads
total_param_size = sum(p.data.numel() for p in params)
fp32_params = params[0].new(0).float().new(total_param_size)
offset = 0
for p in params:
numel = p.data.numel()
fp32_params[offset:offset+numel].copy_(p.data.view(-1))
offset += numel
fp32_params = torch.nn.Parameter(fp32_params)
fp32_params.grad = fp32_params.data.new(total_param_size)
fp32_optimizer = optim.build_optimizer(args, [fp32_params])
return FP16Optimizer(args, params, fp32_optimizer, fp32_params)
@property
def optimizer(self):
return self.fp32_optimizer.optimizer
@property
def optimizer_config(self):
return self.fp32_optimizer.optimizer_config
def get_lr(self):
return self.fp32_optimizer.get_lr()
def set_lr(self, lr):
self.fp32_optimizer.set_lr(lr)
def state_dict(self):
"""Return the optimizer's state dict."""
state_dict = self.fp32_optimizer.state_dict()
state_dict['loss_scale'] = self.scaler.loss_scale
return state_dict
def load_state_dict(self, state_dict, optimizer_overrides=None):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
if 'loss_scale' in state_dict:
self.scaler.loss_scale = state_dict['loss_scale']
self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
def backward(self, loss):
loss = loss * self.scaler.loss_scale
loss.backward()
self._needs_sync = True
def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
if self._needs_sync:
# copy FP16 grads to FP32
offset = 0
for p in self.params:
if not p.requires_grad:
continue
numel = p.grad.data.numel()
self.fp32_params.grad.data[offset:offset+numel].copy_(p.grad.data.view(-1))
offset += numel
# correct for dynamic loss scaler
self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale)
self._needs_sync = False
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
if self._needs_sync:
self._sync_fp16_grads_to_fp32(c)
else:
self.fp32_params.grad.data.mul_(c)
def clip_grad_norm(self, max_norm):
"""Clips gradient norm and updates dynamic loss scaler."""
self._sync_fp16_grads_to_fp32()
grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)
# detect overflow and adjust loss scale
overflow = DynamicLossScaler.has_overflow(grad_norm)
self.scaler.update_scale(overflow)
if overflow:
if self.scaler.loss_scale <= self.args.min_loss_scale:
raise Exception((
'Minimum loss scale reached ({}). Your loss is probably exploding. '
'Try lowering the learning rate, using gradient clipping or '
'increasing the batch size.'
).format(self.args.min_loss_scale))
raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
return grad_norm
def step(self, closure=None):
"""Performs a single optimization step."""
self._sync_fp16_grads_to_fp32()
self.fp32_optimizer.step(closure)
# copy FP32 params back into FP16 model
offset = 0
for p in self.params:
if not p.requires_grad:
continue
numel = p.data.numel()
p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
offset += numel
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
self.fp32_optimizer.zero_grad()
for p in self.params:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
self._needs_sync = False
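Taken together, the FP16 path keeps FP16 parameters for forward/backward, mirrors them in one flat FP32 buffer, scales the loss up before backward, and unscales the gradients during the FP16-to-FP32 copy before the FP32 optimizer step; DynamicLossScaler halves the scale on overflow and grows it again after a window of clean steps. A rough sketch of the call order the new Trainer uses (the helper and its arguments are illustrative, not a fairseq API):

def fp16_train_step(optimizer, model, criterion, task, sample, clip_norm):
    # optimizer is assumed to come from optim.FP16Optimizer.build_optimizer(args, params)
    optimizer.zero_grad()
    loss, sample_size, logging_output = task.get_loss(model, criterion, sample)
    optimizer.backward(loss)                    # multiplies the loss by the current loss scale
    optimizer.multiply_grads(1. / sample_size)  # folded into the lazy FP16->FP32 grad copy
    grad_norm = optimizer.clip_grad_norm(clip_norm)  # may raise OverflowError on inf/nan grads
    optimizer.step()                            # FP32 step, then copy back into the FP16 params
    return grad_norm, logging_output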

View File

@ -183,6 +183,10 @@ def add_distributed_training_args(parser):
help='port number (not required if using --distributed-init-method)')
group.add_argument('--device-id', default=0, type=int,
help='which GPU to use (usually configured automatically)')
group.add_argument('--no-c10d', action='store_true',
help='don\'t use c10d distributed backend')
group.add_argument('--c10d-bucket-cap-mb', default=150, metavar='MB',
help='bucket size for c10d backend')
return group

View File

@ -15,7 +15,7 @@ from itertools import chain
import torch
from fairseq import distributed_utils, optim, utils
from fairseq import distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler
@ -23,22 +23,27 @@ from fairseq.optim import lr_scheduler
class Trainer(object):
"""Main class for data parallel training.
This class supports data parallel training, where multiple workers each
have a full model replica and gradients are accumulated synchronously via
torch.distributed.all_reduce.
This class supports synchronous distributed data parallel training,
where multiple workers each have a full model replica and gradients
are accumulated across workers before each update. We use
:class:`~torch.nn.parallel.DistributedDataParallel` to handle
communication of the gradients across workers.
"""
def __init__(self, args, task, model, criterion):
def __init__(self, args, task, model, criterion, dummy_batch):
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
self.args = args
self.task = task
# copy model and criterion to current device
self.task = task
self.model = model.cuda()
self.criterion = criterion.cuda()
if args.fp16:
self._model = model.half().cuda()
else:
self._model = model.cuda()
# initialize meters
self.meters = OrderedDict()
@ -53,14 +58,27 @@ class Trainer(object):
self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory
if args.fp16:
self.meters['loss_scale'] = AverageMeter() # dynamic loss scale
self.meters['wall'] = TimeMeter() # wall time in seconds
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
self._buffered_stats = defaultdict(lambda: [])
self._flat_grads = None
self._dummy_batch = dummy_batch
self._num_updates = 0
self._optim_history = None
self._optimizer = None
self._wrapped_model = None
@property
def model(self):
if self._wrapped_model is None:
if self.args.distributed_world_size > 1:
self._wrapped_model = models.DistributedFairseqModel(
self.args, self._model,
)
else:
self._wrapped_model = self._model
return self._wrapped_model
@property
def optimizer(self):
@ -69,7 +87,17 @@ class Trainer(object):
return self._optimizer
def _build_optimizer(self):
self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
if self.args.fp16:
if torch.cuda.get_device_capability(0)[0] < 7:
print('| WARNING: your device does NOT support faster training with --fp16, '
'please switch to FP32 which is likely to be faster')
params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
else:
if torch.cuda.get_device_capability(0)[0] >= 7:
print('| NOTICE: your device may support faster training with --fp16')
self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)
def save_checkpoint(self, filename, extra_state):
@ -77,31 +105,27 @@ class Trainer(object):
if distributed_utils.is_master(self.args): # only save one checkpoint
extra_state['train_meters'] = self.meters
utils.save_state(
filename, self.args, self.model, self.criterion, self.optimizer,
filename, self.args, self.get_model(), self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
)
def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
"""Load all training state from a checkpoint file."""
extra_state, self._optim_history, last_optim_state = \
utils.load_model_state(filename, self.model)
utils.load_model_state(filename, self.get_model())
if last_optim_state is not None and not reset_optimizer:
# rebuild optimizer after loading model, since params may have changed
self._build_optimizer()
# only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1]
assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
'criterion does not match; please reset the optimizer (--reset-optimizer)'
assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
'optimizer does not match; please reset the optimizer (--reset-optimizer)'
if not reset_lr_scheduler:
self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
self._num_updates = last_optim['num_updates']
@ -117,7 +141,7 @@ class Trainer(object):
return extra_state
def train_step(self, sample, update_params=True, dummy_batch=False):
def train_step(self, samples, dummy_batch=False):
"""Do forward, backward and parameter update."""
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
@ -125,230 +149,164 @@ class Trainer(object):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
self.model.train()
self.zero_grad()
if not dummy_batch:
self.meters['train_wall'].start()
# forward and backward pass
sample = self._prepare_sample(sample)
loss, sample_size, logging_output, oom_fwd = self._forward(sample)
oom_bwd = self._backward(loss)
logging_outputs, sample_sizes, ooms = [], [], 0
for i, sample in enumerate(samples):
sample = self._prepare_sample(sample)
if sample is None:
# when sample is None, run forward/backward on a dummy batch
# and ignore the resulting gradients
sample = self._prepare_sample(self._dummy_batch)
ignore_grad = True
else:
ignore_grad = False
# buffer stats and logging outputs
self._buffered_stats['sample_sizes'].append(sample_size)
self._buffered_stats['logging_outputs'].append(logging_output)
self._buffered_stats['ooms_fwd'].append(oom_fwd)
self._buffered_stats['ooms_bwd'].append(oom_bwd)
try:
# forward
loss, sample_size, logging_output = self.task.get_loss(
self.model, self.criterion, sample,
)
if ignore_grad:
loss *= 0
# update parameters
if update_params:
agg_logging_output = self._update_params()
else:
agg_logging_output = None # buffering updates
if self.args.distributed_world_size > 1:
# only all-reduce gradients in the last backwards pass
if i < len(samples) - 1:
self.model.need_reduction = False
else:
self.model.need_reduction = True
if not dummy_batch:
self.meters['train_wall'].stop()
# backward
self.optimizer.backward(loss)
return agg_logging_output
if not ignore_grad:
logging_outputs.append(logging_output)
sample_sizes.append(sample_size)
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
ooms += 1
self.zero_grad()
else:
raise e
if dummy_batch:
return None
def _update_params(self):
# gather logging outputs from all replicas
sample_sizes = self._buffered_stats['sample_sizes']
logging_outputs = self._buffered_stats['logging_outputs']
ooms_fwd = self._buffered_stats['ooms_fwd']
ooms_bwd = self._buffered_stats['ooms_bwd']
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
lambda l: list(chain.from_iterable(l)),
zip(*distributed_utils.all_gather_list(
(sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
))
)
ooms_fwd = sum(ooms_fwd)
ooms_bwd = sum(ooms_bwd)
logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list(
[logging_outputs, sample_sizes, ooms],
))
logging_outputs = list(chain.from_iterable(logging_outputs))
sample_sizes = list(chain.from_iterable(sample_sizes))
ooms = sum(ooms)
if ooms_fwd == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping batch')
if ooms == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping update')
self.zero_grad()
return None
# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
# aggregate logging outputs and sample sizes
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
sample_size = self.criterion.__class__.grad_denom(sample_sizes)
if not all(k in logging_output for k in ['ntokens', 'nsentences']):
raise Exception((
'Please update the {}.aggregate_logging_outputs() method to '
'return ntokens and nsentences'
).format(self.criterion.__class__.__name__))
try:
# all-reduce and rescale gradients, then take an optimization step
grad_norm = self._all_reduce_and_rescale(grad_denom)
self._opt()
# normalize grads by sample size
self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))
# clip grads
grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
# take an optimization step
self.optimizer.step()
self._num_updates += 1
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
# update meters
ntokens = logging_output.get('ntokens', 0)
nsentences = logging_output.get('nsentences', 0)
self.meters['wps'].update(ntokens)
self.meters['ups'].update(1.)
self.meters['wpb'].update(ntokens)
self.meters['bsz'].update(nsentences)
if grad_norm is not None:
self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
self.meters['oom'].update(ooms_fwd + ooms_bwd)
# update loss meters for training
if 'loss' in agg_logging_output:
self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
# criterions can optionally log the NLL loss too
if 'nll_loss' in agg_logging_output:
self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(
1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
)
self.meters['oom'].update(ooms)
self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
except OverflowError as e:
self.zero_grad()
print('| WARNING: overflow detected, ' + str(e))
self.zero_grad()
logging_output = None
self.clear_buffered_stats()
if self.args.fp16:
self.meters['loss_scale'].reset()
self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)
return agg_logging_output
self.meters['train_wall'].stop()
def _forward(self, sample, eval=False):
loss = None
sample_size = 0
logging_output = {
'ntokens': sample['ntokens'] if sample is not None else 0,
'nsentences': sample['target'].size(0) if sample is not None else 0,
}
oom = 0
try:
# prepare model and optimizer
if eval:
self.model.eval()
else:
self.model.train()
if sample is not None:
with torch.no_grad() if eval else contextlib.ExitStack():
# calculate loss and sample size
loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
logging_output.update(logging_output_)
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
loss = None
else:
raise e
return loss, sample_size, logging_output, oom
def _backward(self, loss):
oom = 0
if loss is not None:
try:
# backward pass
loss.backward()
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
self.zero_grad()
else:
raise e
return oom
def _all_reduce_and_rescale(self, grad_denom):
# flatten grads into a single buffer and all-reduce
flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
if self.args.distributed_world_size > 1:
torch.distributed.all_reduce(flat_grads)
# rescale and clip gradients
flat_grads.div_(grad_denom)
grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)
# copy grads back into model parameters
self._set_flat_grads(flat_grads)
return grad_norm
def _get_grads(self):
grads = []
for name, p in self.model.named_parameters():
if not p.requires_grad:
continue
if p.grad is None:
print('WARNING: model parameter did not receive gradient: ' + name + '. '
'Check that you\'re using the param in the forward pass or set requires_grad=False')
grads.append(p.new_zeros(p.shape))
else:
grads.append(p.grad.data)
return grads
def _get_flat_grads(self, out=None):
grads = self._get_grads()
if out is None:
grads_size = sum(g.numel() for g in grads)
out = grads[0].new(grads_size).zero_()
offset = 0
for g in grads:
numel = g.numel()
out[offset:offset+numel].copy_(g.view(-1))
offset += numel
return out[:offset]
def _set_flat_grads(self, new_grads):
grads = self._get_grads()
offset = 0
for g in grads:
numel = g.numel()
g.copy_(new_grads[offset:offset+numel].view_as(g))
offset += numel
def _opt(self):
# take an optimization step
self.optimizer.step()
self.zero_grad()
self._num_updates += 1
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
return logging_output
def valid_step(self, sample):
"""Do forward pass in evaluation mode."""
# forward pass
sample = self._prepare_sample(sample)
_loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
assert not oom_fwd, 'Ran out of memory during validation'
self.model.eval()
# gather logging outputs from all GPUs
logging_output, sample_size = {}, 0
with torch.no_grad():
sample = self._prepare_sample(sample)
if sample is None:
sample = self._prepare_sample(self._dummy_batch)
_loss, sample_size, logging_output = self.task.get_loss(
self.model, self.criterion, sample,
)
# gather logging outputs from all replicas
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
(sample_size, logging_output)
logging_output, sample_size = zip(*distributed_utils.all_gather_list(
[logging_output, sample_size],
))
logging_output = list(logging_output)
sample_size = list(sample_size)
else:
sample_sizes = [sample_size]
logging_outputs = [logging_output]
logging_output = [logging_output]
sample_size = [sample_size]
# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
# aggregate logging outputs and sample sizes
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_output)
sample_size = self.criterion.__class__.grad_denom(sample_size)
# update loss meters for validation
if 'loss' in agg_logging_output:
self.meters['valid_loss'].update(agg_logging_output['loss'], grad_denom)
# criterions can optionally log the NLL loss too
if 'nll_loss' in agg_logging_output:
self.meters['valid_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
# update meters for validation
ntokens = logging_output.get('ntokens', 0)
self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
return agg_logging_output
return logging_output
def dummy_train_step(self, dummy_batch):
"""Dummy training step for warming caching allocator."""
self.train_step(dummy_batch, update_params=False, dummy_batch=True)
self.train_step(dummy_batch, dummy_batch=True)
self.zero_grad()
self.clear_buffered_stats()
def zero_grad(self):
self.optimizer.zero_grad()
def clear_buffered_stats(self):
self._buffered_stats.clear()
def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate based on the validation loss."""
return self.lr_scheduler.step(epoch, val_loss)
@ -362,8 +320,8 @@ class Trainer(object):
return self.optimizer.get_lr()
def get_model(self):
"""Get the model replica."""
return self.model
"""Get the (non-wrapped) model instance."""
return self._model
def get_meter(self, name):
"""Get a specific meter by name."""

View File

@ -19,8 +19,10 @@ from train import main as single_process_main
def main(args):
# Set distributed training parameters for a single node.
args.distributed_world_size = torch.cuda.device_count()
args.distributed_init_method = 'tcp://localhost:{port}'.format(
port=random.randint(10000, 20000))
port = random.randint(10000, 20000)
args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
args.distributed_init_host = 'localhost'
args.distributed_port = port + 1
mp = torch.multiprocessing.get_context('spawn')

View File

@ -35,7 +35,7 @@ bleu = Extension(
setup(
name='fairseq',
version='0.5.0',
version='0.6.0',
description='Facebook AI Research Sequence-to-Sequence Toolkit',
long_description=readme,
license=license,

View File

@ -16,7 +16,7 @@ import math
import torch
from fairseq import distributed_utils, options, progress_bar, tasks, utils
from fairseq.fp16_trainer import FP16Trainer
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
@ -43,16 +43,17 @@ def main(args):
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))
# Make a dummy batch to (i) warm the caching allocator and (ii) serve as a
# placeholder batch for DistributedDataParallel when there's an uneven number
# of batches per worker.
max_positions = utils.resolve_max_positions(
task.max_positions(),
model.max_positions(),
)
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
# Build trainer
if args.fp16:
if torch.cuda.get_device_capability(0)[0] < 7:
print('| WARNING: your device does NOT support faster training with --fp16,'
' please switch to FP32 which is likely to be faster')
trainer = FP16Trainer(args, task, model, criterion)
else:
if torch.cuda.get_device_capability(0)[0] >= 7:
print('| NOTICE: your device may support faster training with --fp16')
trainer = Trainer(args, task, model, criterion)
trainer = Trainer(args, task, model, criterion, dummy_batch)
print('| training on {} GPUs'.format(args.distributed_world_size))
print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
args.max_tokens,
@ -60,10 +61,6 @@ def main(args):
))
# Initialize dataloader
max_positions = utils.resolve_max_positions(
task.max_positions(),
trainer.get_model().max_positions(),
)
epoch_itr = task.get_batch_iterator(
dataset=task.dataset(args.train_subset),
max_tokens=args.max_tokens,
@ -78,9 +75,7 @@ def main(args):
# Load the latest checkpoint if one is available
if not load_checkpoint(args, trainer, epoch_itr):
# Send a dummy batch to warm the caching allocator
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
trainer.dummy_train_step(dummy_batch)
trainer.dummy_train_step([dummy_batch])
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
@ -110,32 +105,32 @@ def main(args):
def train(args, trainer, task, epoch_itr):
"""Train the model for one epoch."""
# Initialize data iterator
itr = epoch_itr.next_epoch_itr()
progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')
# update parameters every N batches
# Update parameters every N batches
if epoch_itr.epoch <= len(args.update_freq):
update_freq = args.update_freq[epoch_itr.epoch - 1]
else:
update_freq = args.update_freq[-1]
# Initialize data iterator
itr = epoch_itr.next_epoch_itr()
itr = iterators.GroupedIterator(itr, update_freq)
progress = progress_bar.build_progress_bar(
args, itr, epoch_itr.epoch, no_progress_bar='simple',
)
extra_meters = collections.defaultdict(lambda: AverageMeter())
first_valid = args.valid_subset.split(',')[0]
max_update = args.max_update or math.inf
num_batches = len(epoch_itr)
for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
if i < num_batches - 1 and (i + 1) % update_freq > 0:
# buffer updates according to --update-freq
trainer.train_step(sample, update_params=False)
for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
log_output = trainer.train_step(samples)
if log_output is None:
continue
else:
log_output = trainer.train_step(sample, update_params=True)
# log mid-epoch stats
stats = get_training_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
continue # these are already logged above
if 'loss' in k:
extra_meters[k].update(v, log_output['sample_size'])
@ -163,7 +158,9 @@ def train(args, trainer, task, epoch_itr):
progress.print(stats)
# reset training meters
for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip', 'gnorm']:
for k in [
'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
]:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
@ -230,7 +227,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
log_output = trainer.valid_step(sample)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
continue
extra_meters[k].update(v)