Refactor distributed code under fairseq.distributed (#1546)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1546

Test Plan: Imported from OSS

Reviewed By: girifb

Differential Revision: D25836853

Pulled By: myleott

fbshipit-source-id: c5076615d49774633ecfaf0aa68b68e8b2331bd9
Myle Ott 2021-01-28 14:18:48 -08:00 committed by Facebook GitHub Bot
parent 922528d58f
commit d68a3530dd
10 changed files with 391 additions and 117 deletions


@@ -28,6 +28,7 @@ from fairseq.dataclass.initialize import hydra_init
hydra_init()
import fairseq.criterions # noqa
import fairseq.distributed # noqa
import fairseq.models # noqa
import fairseq.modules # noqa
import fairseq.optim # noqa


@@ -0,0 +1,17 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .distributed_timeout_wrapper import DistributedTimeoutWrapper
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
from .module_proxy_wrapper import ModuleProxyWrapper
from .tpu_distributed_data_parallel import TPUDistributedDataParallel
__all__ = [
"DistributedTimeoutWrapper",
"LegacyDistributedDataParallel",
"ModuleProxyWrapper",
"TPUDistributedDataParallel",
]
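
For downstream code, the practical effect of the new package above is a single import location for all of the distributed wrappers. A minimal sketch, assuming fairseq is installed at this commit; the old per-module path mentioned in the comment is taken from the import that this diff removes further down:

from fairseq.distributed import (  # previously e.g. fairseq.legacy_distributed_data_parallel
    DistributedTimeoutWrapper,
    LegacyDistributedDataParallel,
    ModuleProxyWrapper,
    TPUDistributedDataParallel,
)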


@@ -0,0 +1,94 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import signal
import threading
from torch import nn
logger = logging.getLogger(__name__)
class DistributedTimeoutWrapper(nn.Module):
"""
A wrapper that kills the process if no progress is made within a given
*timeout*. The timer is reset every time :func:`forward` is called.
Usage::
module = DistributedTimeoutWrapper(module, timeout=30)
x = module(input)
time.sleep(20) # safe
x = module(input)
time.sleep(45) # job will be killed before this returns
Args:
module (nn.Module): module to wrap
timeout (int): number of seconds before killing the process
(set to a value <= 0 to disable the timeout)
signal (Optional): signal to send once timeout is triggered
"""
def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGKILL):
super().__init__()
self.module = module
self.timeout = timeout
self.signal = signal
if timeout > 0:
self._heartbeat = threading.Event()
self._heartbeat_thread = threading.Thread(
target=self._check_heartbeat,
args=(os.getpid(),),
daemon=True,
)
self._heartbeat_thread.start()
self._terminated = False
else:
self._heartbeat = None
self._heartbeat_thread = None
def __del__(self):
self.stop_timeout()
def __getattr__(self, name):
"""Forward missing attributes to wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.module, name)
def stop_timeout(self):
if self._heartbeat_thread is not None:
self._terminated = True
self._heartbeat_thread.join()
def state_dict(self, *args, **kwargs):
return self.module.state_dict(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
return self.module.load_state_dict(*args, **kwargs)
def forward(self, *args, **kwargs):
if self._heartbeat is not None:
self._heartbeat.set()
return self.module(*args, **kwargs)
def _check_heartbeat(self, parent_pid):
self._heartbeat.wait() # wait for the first forward pass
while True:
self._heartbeat.clear()
success = self._heartbeat.wait(timeout=self.timeout)
if self._terminated:
break
elif not success:
logger.error((
"Killing job for not making progress in {} seconds. "
"Set --heartbeat-timeout=-1 to disable this timeout."
).format(int(self.timeout)))
os.kill(parent_pid, self.signal)
return
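
A hedged usage sketch of the wrapper above, expanding on its docstring and mirroring the unit test near the end of this diff; the nn.Linear stand-in module and the SIGINT override are illustrative assumptions, not part of this commit:

import signal

import torch
from torch import nn

from fairseq.distributed import DistributedTimeoutWrapper

toy_module = nn.Linear(5, 5)  # stand-in for a real fairseq model
# Override the default SIGKILL with SIGINT so a hung run raises a
# KeyboardInterrupt that callers (and the tests below) can catch.
wrapped = DistributedTimeoutWrapper(toy_module, timeout=60, signal=signal.SIGINT)
out = wrapped(torch.rand(2, 5))  # every forward() resets the heartbeat timer
wrapped.stop_timeout()           # shut down the watchdog thread cleanly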


@@ -14,15 +14,13 @@ This version also supports the *no_sync* context manager, which allows faster
training with `--update-freq`.
"""
import copy
from collections import OrderedDict
from contextlib import contextmanager
import torch
from torch import nn
from torch.autograd import Variable
from . import distributed_utils
from fairseq import distributed_utils
class LegacyDistributedDataParallel(nn.Module):
@@ -64,13 +62,6 @@ class LegacyDistributedDataParallel(nn.Module):
paramlists[device] += [param]
self.per_device_params = list(paramlists.values())
def __getstate__(self):
attrs = copy.copy(self.__dict__)
return attrs
def __setstate__(self, state):
super().__setstate__(state)
@contextmanager
def no_sync(self):
"""A context manager to disable gradient synchronization."""


@@ -0,0 +1,55 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn
class ModuleProxyWrapper(nn.Module):
"""
Wrap a DistributedDataParallel module and forward requests for missing
attributes to the module wrapped by DDP (the twice-wrapped module).
Also forward calls to :func:`state_dict` and :func:`load_state_dict`.
Usage::
module.xyz = "hello world"
wrapped_module = DistributedDataParallel(module, **ddp_args)
wrapped_module = ModuleProxyWrapper(wrapped_module)
assert wrapped_module.xyz == "hello world"
assert wrapped_module.state_dict().keys() == module.state_dict().keys()
Args:
module (nn.Module): module to wrap
"""
def __init__(self, module: nn.Module):
super().__init__()
assert hasattr(module, "module"), \
"ModuleProxyWrapper expects input to wrap another module"
self.module = module
def __getattr__(self, name):
"""Forward missing attributes to twice-wrapped module."""
try:
# defer to nn.Module's logic
return super().__getattr__(name)
except AttributeError:
try:
# forward to the once-wrapped module
return getattr(self.module, name)
except AttributeError:
# forward to the twice-wrapped module
return getattr(self.module.module, name)
def state_dict(self, *args, **kwargs):
"""Forward to the twice-wrapped module."""
return self.module.module.state_dict(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
"""Forward to the twice-wrapped module."""
return self.module.module.load_state_dict(*args, **kwargs)
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
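
A hedged sketch of the two-level attribute forwarding described above; DummyDDP is an illustrative stand-in for DistributedDataParallel (which needs an initialized process group) and mirrors the MockDDPWrapper used in the tests at the end of this diff:

from torch import nn

from fairseq.distributed import ModuleProxyWrapper

class DummyDDP(nn.Module):
    """Stand-in for DDP: exposes the wrapped model as .module, as DDP does."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

model = nn.Linear(5, 10)
model.xyz = "hello world"
proxy = ModuleProxyWrapper(DummyDDP(model))
assert proxy.xyz == "hello world"  # falls through DummyDDP to the inner model
assert proxy.state_dict().keys() == model.state_dict().keys()  # no "module." prefix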


@@ -0,0 +1,43 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from fairseq import distributed_utils
class TPUDistributedDataParallel(nn.Module):
def __init__(self, module, process_group):
super().__init__()
self.module = module
self.process_group = process_group
self.world_size = distributed_utils.get_world_size(self.process_group)
def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)
def all_reduce_grads(self):
gradients = []
for p in self.parameters():
if not p.requires_grad:
continue
if p.grad is None:
p.grad = torch.zeros_like(p)
if p.grad.requires_grad:
raise RuntimeError(
"TPUDistributedDataParallel only works with gradients that don't "
"require grad"
)
gradients.append(p.grad)
import torch_xla.core.xla_model as xm
xm.all_reduce(
'sum',
gradients,
scale=1. / self.world_size,
groups=self.process_group[1],
)


@@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import logging
import os
import signal
@@ -11,9 +10,15 @@ import threading
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from fairseq import distributed_utils
from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
from fairseq.distributed import (
DistributedTimeoutWrapper,
LegacyDistributedDataParallel,
ModuleProxyWrapper,
TPUDistributedDataParallel,
)
logger = logging.getLogger(__name__)
@@ -26,7 +31,7 @@ except ImportError:
_GOSSIP_DISABLED = True
def DistributedFairseqModel(args, model, process_group):
def DistributedFairseqModel(args, model, process_group, device):
"""
Wrap a *model* to support distributed data parallel training.
@@ -40,42 +45,42 @@ def DistributedFairseqModel(args, model, process_group):
model (BaseFairseqModel): model to wrap
process_group: the c10d process group to be used for distributed data
parallel all-reduction.
device: device to move model to
"""
# determine which DDP class to extend
assert isinstance(model, nn.Module)
if args.tpu:
ddp_class = TPUDistributedDataParallel
init_kwargs = dict(
module=model,
wrapped_model = TPUDistributedDataParallel(
module=model.to(device),
process_group=process_group,
)
# forward missing getattr and state_dict/load_state_dict to orig model
wrapped_model = ModuleProxyWrapper(wrapped_model)
elif args.ddp_backend in {"c10d", "pytorch_ddp"}:
ddp_class = nn.parallel.DistributedDataParallel
init_kwargs = dict(
module=model,
wrapped_model = DistributedDataParallel(
module=model.to(device),
device_ids=[args.device_id],
output_device=args.device_id,
broadcast_buffers=args.broadcast_buffers,
bucket_cap_mb=args.bucket_cap_mb,
process_group=process_group,
find_unused_parameters=args.find_unused_parameters,
)
# Maintain backward compatibility
if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]:
init_kwargs["find_unused_parameters"] = args.find_unused_parameters
# forward missing getattr and state_dict/load_state_dict to orig model
wrapped_model = ModuleProxyWrapper(wrapped_model)
elif args.ddp_backend in {"no_c10d", "legacy_ddp"}:
ddp_class = LegacyDistributedDataParallel
init_kwargs = dict(
module=model,
wrapped_model = LegacyDistributedDataParallel(
module=model.to(device),
buffer_size=2 ** 28,
process_group=process_group,
)
# forward missing getattr and state_dict/load_state_dict to orig model
wrapped_model = ModuleProxyWrapper(wrapped_model)
elif args.ddp_backend == "slow_mo":
if _GOSSIP_DISABLED:
raise ImportError(
"Cannot find gossip library. Please install from: "
"github.com/facebookresearch/stochastic_gradient_push"
)
ddp_class = gossip.GossipDataParallel
# The values of slowmo_momentum below were obtained by tuning on the
# En-De 16 dataset by training the transformer_wmt_en_de_large model
@@ -89,8 +94,8 @@ def DistributedFairseqModel(args, model, process_group):
else:
args.slowmo_momentum = 0.6
init_kwargs = dict(
module=model,
wrapped_model = gossip.GossipDataParallel(
module=model.to(device),
device_ids=[args.device_id],
output_device=args.device_id,
broadcast_buffers=args.broadcast_buffers,
@@ -99,88 +104,14 @@ def DistributedFairseqModel(args, model, process_group):
localsgd=(args.slowmo_algorithm == "LocalSGD"),
localsgd_frequency=args.localsgd_frequency,
)
# forward missing getattr and state_dict/load_state_dict to orig model
wrapped_model = ModuleProxyWrapper(wrapped_model)
else:
raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)
heartbeat_timeout = getattr(args, "heartbeat_timeout", -1)
# kill hung distributed jobs after a timeout
wrapped_model = DistributedTimeoutWrapper(
wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1)
)
class _DistributedFairseqModel(ddp_class):
"""
Extend DistributedDataParallel to check for missing attributes in the
wrapped module and to add a timeout to kill the job if no progress is
made (--heartbeat-timeout).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._heartbeat_timeout = heartbeat_timeout
if self._heartbeat_timeout > 0:
self._heartbeat = threading.Event()
self._heartbeat_thread = threading.Thread(
target=self._check_heartbeat,
args=(os.getpid(),),
daemon=True,
)
self._heartbeat_thread.start()
else:
self._heartbeat = None
def _check_heartbeat(self, parent_pid):
self._heartbeat.wait() # wait for the first forward pass
while True:
self._heartbeat.clear()
success = self._heartbeat.wait(timeout=self._heartbeat_timeout)
if not success:
logger.error((
"Killing job for not making progress in {} seconds. "
"Set --heartbeat-timeout=-1 to disable this timeout."
).format(int(self._heartbeat_timeout)))
os.kill(parent_pid, signal.SIGKILL)
return
def __getattr__(self, name):
wrapped_module = super().__getattr__("module")
if hasattr(wrapped_module, name):
return getattr(wrapped_module, name)
return super().__getattr__(name)
def forward(self, *args, **kwargs):
if self._heartbeat is not None:
self._heartbeat.set()
return super().forward(*args, **kwargs)
return _DistributedFairseqModel(**init_kwargs)
class TPUDistributedDataParallel(nn.Module):
def __init__(self, module, process_group):
super().__init__()
self.module = module
self.process_group = process_group
self.world_size = distributed_utils.get_world_size(self.process_group)
def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)
def all_reduce_grads(self):
gradients = []
for p in self.parameters():
if not p.requires_grad:
continue
if p.grad is None:
p.grad = torch.zeros_like(p)
if p.grad.requires_grad:
raise RuntimeError(
"TPUDistributedDataParallel only works with gradients that don't "
"require grad"
)
gradients.append(p.grad)
import torch_xla.core.xla_model as xm
xm.all_reduce(
'sum',
gradients,
scale=1. / self.world_size,
groups=self.process_group[1],
)
return wrapped_model
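
For orientation, a hedged sketch of how the refactored factory is now invoked; it mirrors the trainer call in the next file's diff, and cfg, model, process_group, and device exist here only as parameters:

from fairseq import models

def build_distributed_model(cfg, model, process_group, device):
    # The factory now also receives the target device and moves the model there
    # itself; whatever backend is chosen, the returned object proxies getattr/
    # state_dict to the original model via ModuleProxyWrapper and is wrapped in
    # a DistributedTimeoutWrapper controlled by --heartbeat-timeout.
    return models.DistributedFairseqModel(
        cfg.distributed_training,
        model,
        process_group=process_group,
        device=device,
    )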


@@ -69,7 +69,12 @@ class Trainer(object):
elif cfg.common.bf16:
self._criterion = self._criterion.to(dtype=torch.bfloat16)
self._model = self._model.to(dtype=torch.bfloat16)
if not cfg.distributed_training.pipeline_model_parallel:
if (
not cfg.distributed_training.pipeline_model_parallel
# the DistributedFairseqModel wrapper will handle moving to device,
# so only handle cases which don't use the wrapper
and not self.use_distributed_wrapper
):
self._criterion = self._criterion.to(device=self.device)
self._model = self._model.to(device=self.device)
self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
@@ -158,18 +163,25 @@ class Trainer(object):
# parallel rank 0
return self.data_parallel_rank == 0
@property
def use_distributed_wrapper(self) -> bool:
return (
self.data_parallel_world_size > 1
and not self.cfg.optimization.use_bmuf
)
@property
def criterion(self):
if self._wrapped_criterion is None:
if (
utils.has_parameters(self._criterion)
and self.data_parallel_world_size > 1
and not self.cfg.optimization.use_bmuf
and self.use_distributed_wrapper
):
self._wrapped_criterion = models.DistributedFairseqModel(
self.cfg.distributed_training,
self._criterion,
process_group=self.data_parallel_process_group,
device=self.device,
)
else:
self._wrapped_criterion = self._criterion
@@ -178,11 +190,12 @@ class Trainer(object):
@property
def model(self):
if self._wrapped_model is None:
if self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf:
if self.use_distributed_wrapper:
self._wrapped_model = models.DistributedFairseqModel(
self.cfg.distributed_training,
self._model,
process_group=self.data_parallel_process_group,
device=self.device,
)
else:
self._wrapped_model = self._model
@@ -268,8 +281,8 @@ class Trainer(object):
checkpoint_utils.save_state(
filename,
self.cfg,
self.get_model().state_dict(),
self.get_criterion(),
self.model.state_dict(),
self.criterion,
self.optimizer,
self.lr_scheduler,
self.get_num_updates(),
@@ -336,7 +349,7 @@ class Trainer(object):
# load model parameters
try:
self.get_model().load_state_dict(
self.model.load_state_dict(
state["model"], strict=True, model_cfg=self.cfg.model
)
if utils.has_parameters(self.get_criterion()):


@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import signal
import time
import unittest
import torch
from torch import nn
from fairseq.distributed import DistributedTimeoutWrapper
class ModuleWithDelay(nn.Module):
def __init__(self, delay):
super().__init__()
self.delay = delay
def forward(self, x):
time.sleep(self.delay)
return x
class TestDistributedTimeoutWrapper(unittest.TestCase):
def setUp(self):
logging.disable(logging.CRITICAL)
def tearDown(self):
logging.disable(logging.NOTSET)
def test_no_timeout(self):
module = DistributedTimeoutWrapper(ModuleWithDelay(1), 0, signal.SIGINT)
module(torch.rand(5))
module.stop_timeout()
def test_timeout_safe(self):
module = DistributedTimeoutWrapper(ModuleWithDelay(1), 10, signal.SIGINT)
module(torch.rand(5))
module.stop_timeout()
def test_timeout_killed(self):
with self.assertRaises(KeyboardInterrupt):
module = DistributedTimeoutWrapper(ModuleWithDelay(5), 1, signal.SIGINT)
module(torch.rand(5))
module.stop_timeout()
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,75 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from torch import nn
from fairseq.distributed import ModuleProxyWrapper
from .utils import objects_are_equal
class MockDDPWrapper(nn.Module):
"""A simple wrapper with an interface similar to DistributedDataParallel."""
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, x):
return self.module(x)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(5, 10)
self.xyz = "hello"
def forward(self, x):
return self.linear(x)
def get_xyz(self):
return self.xyz
class TestModuleProxyWrapper(unittest.TestCase):
def _get_module(self):
module = Model()
wrapped_module = MockDDPWrapper(module)
wrapped_module = ModuleProxyWrapper(wrapped_module)
return wrapped_module, module
def test_getattr_forwarding(self):
wrapped_module, module = self._get_module()
assert module.xyz == "hello"
assert module.get_xyz() == "hello"
assert wrapped_module.xyz == "hello"
wrapped_module.xyz = "world"
assert wrapped_module.xyz == "world"
assert module.get_xyz() == "hello"
def test_state_dict(self):
wrapped_module, module = self._get_module()
assert objects_are_equal(wrapped_module.state_dict(), module.state_dict())
def test_load_state_dict(self):
wrapped_module, module = self._get_module()
wrapped_module.load_state_dict(module.state_dict())
input = torch.rand(4, 5)
torch.testing.assert_allclose(wrapped_module(input), module(input))
def test_forward(self):
wrapped_module, module = self._get_module()
input = torch.rand(4, 5)
torch.testing.assert_allclose(wrapped_module(input), module(input))
if __name__ == "__main__":
unittest.main()