diff --git a/fairseq/__init__.py b/fairseq/__init__.py
index ccd45add7..8e51b61be 100644
--- a/fairseq/__init__.py
+++ b/fairseq/__init__.py
@@ -28,6 +28,7 @@ from fairseq.dataclass.initialize import hydra_init
 hydra_init()
 
 import fairseq.criterions  # noqa
+import fairseq.distributed  # noqa
 import fairseq.models  # noqa
 import fairseq.modules  # noqa
 import fairseq.optim  # noqa
diff --git a/fairseq/distributed/__init__.py b/fairseq/distributed/__init__.py
new file mode 100644
index 000000000..7f4016e38
--- /dev/null
+++ b/fairseq/distributed/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .distributed_timeout_wrapper import DistributedTimeoutWrapper
+from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
+from .module_proxy_wrapper import ModuleProxyWrapper
+from .tpu_distributed_data_parallel import TPUDistributedDataParallel
+
+
+__all__ = [
+    "DistributedTimeoutWrapper",
+    "LegacyDistributedDataParallel",
+    "ModuleProxyWrapper",
+    "TPUDistributedDataParallel",
+]
diff --git a/fairseq/distributed/distributed_timeout_wrapper.py b/fairseq/distributed/distributed_timeout_wrapper.py
new file mode 100644
index 000000000..c8ab47707
--- /dev/null
+++ b/fairseq/distributed/distributed_timeout_wrapper.py
@@ -0,0 +1,94 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import signal
+import threading
+
+from torch import nn
+
+
+logger = logging.getLogger(__name__)
+
+
+class DistributedTimeoutWrapper(nn.Module):
+    """
+    A wrapper that kills the process if no progress is made within a given
+    *timeout*. The timer is reset every time :func:`forward` is called.
+
+    Usage::
+
+        module = DistributedTimeoutWrapper(module, timeout=30)
+        x = module(input)
+        time.sleep(20)  # safe
+        x = module(input)
+        time.sleep(45)  # job will be killed before this returns
+
+    Args:
+        module (nn.Module): module to wrap
+        timeout (int): number of seconds before killing the process
+            (set to a value <= 0 to disable the timeout)
+        signal (Optional): signal to send once timeout is triggered
+    """
+    def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGKILL):
+        super().__init__()
+        self.module = module
+        self.timeout = timeout
+        self.signal = signal
+
+        if timeout > 0:
+            self._heartbeat = threading.Event()
+            self._heartbeat_thread = threading.Thread(
+                target=self._check_heartbeat,
+                args=(os.getpid(),),
+                daemon=True,
+            )
+            self._heartbeat_thread.start()
+            self._terminated = False
+        else:
+            self._heartbeat = None
+            self._heartbeat_thread = None
+
+    def __del__(self):
+        self.stop_timeout()
+
+    def __getattr__(self, name):
+        """Forward missing attributes to wrapped module."""
+        try:
+            return super().__getattr__(name)  # defer to nn.Module's logic
+        except AttributeError:
+            return getattr(self.module, name)
+
+    def stop_timeout(self):
+        if self._heartbeat_thread is not None:
+            self._terminated = True
+            self._heartbeat_thread.join()
+
+    def state_dict(self, *args, **kwargs):
+        return self.module.state_dict(*args, **kwargs)
+
+    def load_state_dict(self, *args, **kwargs):
+        return self.module.load_state_dict(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        if self._heartbeat is not None:
+            self._heartbeat.set()
+        return self.module(*args, **kwargs)
+
+    def _check_heartbeat(self, parent_pid):
+        self._heartbeat.wait()  # wait for the first forward pass
+        while True:
+            self._heartbeat.clear()
+            success = self._heartbeat.wait(timeout=self.timeout)
+            if self._terminated:
+                break
+            elif not success:
+                logger.error((
+                    "Killing job for not making progress in {} seconds. "
+                    "Set --heartbeat-timeout=-1 to disable this timeout."
+                ).format(int(self.timeout)))
+                os.kill(parent_pid, self.signal)
+                return
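
A minimal usage sketch (illustrative only, not part of the diff above): the watchdog can be exercised in-process by passing a catchable signal such as SIGINT, which is also how the new unit test below drives it; the module and timings here are arbitrary::

    import signal
    import time

    import torch
    from torch import nn

    from fairseq.distributed import DistributedTimeoutWrapper

    module = DistributedTimeoutWrapper(nn.Linear(5, 5), timeout=3, signal=signal.SIGINT)
    module(torch.rand(2, 5))      # each forward() resets the heartbeat
    try:
        time.sleep(10)            # no forward() for longer than `timeout`
        module(torch.rand(2, 5))  # never reached: SIGINT arrives first
    except KeyboardInterrupt:
        print("watchdog fired")
    finally:
        module.stop_timeout()     # stop the daemon watchdog thread explicitly
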
diff --git a/fairseq/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py
similarity index 96%
rename from fairseq/legacy_distributed_data_parallel.py
rename to fairseq/distributed/legacy_distributed_data_parallel.py
index 7e176eaf3..35de179b2 100644
--- a/fairseq/legacy_distributed_data_parallel.py
+++ b/fairseq/distributed/legacy_distributed_data_parallel.py
@@ -14,15 +14,13 @@ This version also supports the *no_sync* context manager, which allows faster
 training with `--update-freq`.
 """
 
-import copy
 from collections import OrderedDict
 from contextlib import contextmanager
 
 import torch
 from torch import nn
-from torch.autograd import Variable
 
-from . import distributed_utils
+from fairseq import distributed_utils
 
 
 class LegacyDistributedDataParallel(nn.Module):
@@ -64,13 +62,6 @@ class LegacyDistributedDataParallel(nn.Module):
             paramlists[device] += [param]
         self.per_device_params = list(paramlists.values())
 
-    def __getstate__(self):
-        attrs = copy.copy(self.__dict__)
-        return attrs
-
-    def __setstate__(self, state):
-        super().__setstate__(state)
-
     @contextmanager
     def no_sync(self):
         """A context manager to disable gradient synchronization."""
diff --git a/fairseq/distributed/module_proxy_wrapper.py b/fairseq/distributed/module_proxy_wrapper.py
new file mode 100644
index 000000000..fc2c6f8c7
--- /dev/null
+++ b/fairseq/distributed/module_proxy_wrapper.py
@@ -0,0 +1,55 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch import nn
+
+
+class ModuleProxyWrapper(nn.Module):
+    """
+    Wrap a DistributedDataParallel module and forward requests for missing
+    attributes to the module wrapped by DDP (the twice-wrapped module).
+    Also forward calls to :func:`state_dict` and :func:`load_state_dict`.
+
+    Usage::
+
+        module.xyz = "hello world"
+        wrapped_module = DistributedDataParallel(module, **ddp_args)
+        wrapped_module = ModuleProxyWrapper(wrapped_module)
+        assert wrapped_module.xyz == "hello world"
+        assert wrapped_module.state_dict().keys() == module.state_dict().keys()
+
+    Args:
+        module (nn.Module): module to wrap
+    """
+
+    def __init__(self, module: nn.Module):
+        super().__init__()
+        assert hasattr(module, "module"), \
+            "ModuleProxyWrapper expects input to wrap another module"
+        self.module = module
+
+    def __getattr__(self, name):
+        """Forward missing attributes to twice-wrapped module."""
+        try:
+            # defer to nn.Module's logic
+            return super().__getattr__(name)
+        except AttributeError:
+            try:
+                # forward to the once-wrapped module
+                return getattr(self.module, name)
+            except AttributeError:
+                # forward to the twice-wrapped module
+                return getattr(self.module.module, name)
+
+    def state_dict(self, *args, **kwargs):
+        """Forward to the twice-wrapped module."""
+        return self.module.module.state_dict(*args, **kwargs)
+
+    def load_state_dict(self, *args, **kwargs):
+        """Forward to the twice-wrapped module."""
+        return self.module.module.load_state_dict(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
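
A minimal sketch (illustrative only, not part of the diff above) of the two-level forwarding: a trivial stand-in class plays the role of DistributedDataParallel, much like the MockDDPWrapper in the new unit test, so no process group is needed::

    from torch import nn

    from fairseq.distributed import ModuleProxyWrapper

    class FakeDDP(nn.Module):
        """Stand-in for DistributedDataParallel: it just holds a .module."""
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, x):
            return self.module(x)

    model = nn.Linear(5, 10)
    model.xyz = "hello"

    wrapped = ModuleProxyWrapper(FakeDDP(model))
    assert wrapped.xyz == "hello"                                    # attribute falls through both levels
    assert wrapped.state_dict().keys() == model.state_dict().keys()  # no "module." prefix in checkpoints
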
diff --git a/fairseq/distributed/tpu_distributed_data_parallel.py b/fairseq/distributed/tpu_distributed_data_parallel.py
new file mode 100644
index 000000000..2adcf1cb5
--- /dev/null
+++ b/fairseq/distributed/tpu_distributed_data_parallel.py
@@ -0,0 +1,43 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+
+from fairseq import distributed_utils
+
+
+class TPUDistributedDataParallel(nn.Module):
+
+    def __init__(self, module, process_group):
+        super().__init__()
+        self.module = module
+        self.process_group = process_group
+        self.world_size = distributed_utils.get_world_size(self.process_group)
+
+    def forward(self, *inputs, **kwargs):
+        return self.module(*inputs, **kwargs)
+
+    def all_reduce_grads(self):
+        gradients = []
+        for p in self.parameters():
+            if not p.requires_grad:
+                continue
+            if p.grad is None:
+                p.grad = torch.zeros_like(p)
+            if p.grad.requires_grad:
+                raise RuntimeError(
+                    "TPUDistributedDataParallel only works with gradients that don't "
+                    "require grad"
+                )
+            gradients.append(p.grad)
+
+        import torch_xla.core.xla_model as xm
+        xm.all_reduce(
+            'sum',
+            gradients,
+            scale=1. / self.world_size,
+            groups=self.process_group[1],
+        )
diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py
index b8fbc3779..bee103311 100644
--- a/fairseq/models/distributed_fairseq_model.py
+++ b/fairseq/models/distributed_fairseq_model.py
@@ -3,7 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import inspect
 import logging
 import os
 import signal
@@ -11,9 +10,15 @@ import threading
 
 import torch
 import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel
 
 from fairseq import distributed_utils
-from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
+from fairseq.distributed import (
+    DistributedTimeoutWrapper,
+    LegacyDistributedDataParallel,
+    ModuleProxyWrapper,
+    TPUDistributedDataParallel,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -26,7 +31,7 @@ except ImportError:
     _GOSSIP_DISABLED = True
 
 
-def DistributedFairseqModel(args, model, process_group):
+def DistributedFairseqModel(args, model, process_group, device):
     """
     Wrap a *model* to support distributed data parallel training.
 
@@ -40,42 +45,42 @@ def DistributedFairseqModel(args, model, process_group):
         model (BaseFairseqModel): model to wrap
         process_group: the c10d process group to be used for distributed data
             parallel all-reduction.
+        device: device to move model to
     """
-    # determine which DDP class to extend
     assert isinstance(model, nn.Module)
     if args.tpu:
-        ddp_class = TPUDistributedDataParallel
-        init_kwargs = dict(
-            module=model,
+        wrapped_model = TPUDistributedDataParallel(
+            module=model.to(device),
             process_group=process_group,
         )
+        # forward missing getattr and state_dict/load_state_dict to orig model
+        wrapped_model = ModuleProxyWrapper(wrapped_model)
     elif args.ddp_backend in {"c10d", "pytorch_ddp"}:
-        ddp_class = nn.parallel.DistributedDataParallel
-        init_kwargs = dict(
-            module=model,
+        wrapped_model = DistributedDataParallel(
+            module=model.to(device),
             device_ids=[args.device_id],
             output_device=args.device_id,
             broadcast_buffers=args.broadcast_buffers,
             bucket_cap_mb=args.bucket_cap_mb,
             process_group=process_group,
+            find_unused_parameters=args.find_unused_parameters,
         )
-        # Maintain backward compatibility
-        if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]:
-            init_kwargs["find_unused_parameters"] = args.find_unused_parameters
+        # forward missing getattr and state_dict/load_state_dict to orig model
+        wrapped_model = ModuleProxyWrapper(wrapped_model)
     elif args.ddp_backend in {"no_c10d", "legacy_ddp"}:
-        ddp_class = LegacyDistributedDataParallel
-        init_kwargs = dict(
-            module=model,
+        wrapped_model = LegacyDistributedDataParallel(
+            module=model.to(device),
            buffer_size=2 ** 28,
             process_group=process_group,
         )
+        # forward missing getattr and state_dict/load_state_dict to orig model
+        wrapped_model = ModuleProxyWrapper(wrapped_model)
     elif args.ddp_backend == "slow_mo":
         if _GOSSIP_DISABLED:
             raise ImportError(
                 "Cannot find gossip library. Please install from: "
                 "github.com/facebookresearch/stochastic_gradient_push"
             )
-        ddp_class = gossip.GossipDataParallel
 
         # The values of slowmo_momentum below were obtained by tuning on the
         # En-De 16 dataset by training the transformer_wmt_en_de_large model
@@ -89,8 +94,8 @@
         else:
             args.slowmo_momentum = 0.6
 
-        init_kwargs = dict(
-            module=model,
+        wrapped_model = gossip.GossipDataParallel(
+            module=model.to(device),
             device_ids=[args.device_id],
             output_device=args.device_id,
             broadcast_buffers=args.broadcast_buffers,
@@ -99,88 +104,14 @@
             localsgd=(args.slowmo_algorithm == "LocalSGD"),
             localsgd_frequency=args.localsgd_frequency,
         )
+        # forward missing getattr and state_dict/load_state_dict to orig model
+        wrapped_model = ModuleProxyWrapper(wrapped_model)
     else:
         raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)
 
-    heartbeat_timeout = getattr(args, "heartbeat_timeout", -1)
+    # kill hung distributed jobs after a timeout
+    wrapped_model = DistributedTimeoutWrapper(
+        wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1)
+    )
 
-    class _DistributedFairseqModel(ddp_class):
-        """
-        Extend DistributedDataParallel to check for missing attributes in the
-        wrapped module and to add a timeout to kill the job if no progress is
-        made (--heartbeat-timeout).
-        """
-
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self._heartbeat_timeout = heartbeat_timeout
-            if self._heartbeat_timeout > 0:
-                self._heartbeat = threading.Event()
-                self._heartbeat_thread = threading.Thread(
-                    target=self._check_heartbeat,
-                    args=(os.getpid(),),
-                    daemon=True,
-                )
-                self._heartbeat_thread.start()
-            else:
-                self._heartbeat = None
-
-        def _check_heartbeat(self, parent_pid):
-            self._heartbeat.wait()  # wait for the first forward pass
-            while True:
-                self._heartbeat.clear()
-                success = self._heartbeat.wait(timeout=self._heartbeat_timeout)
-                if not success:
-                    logger.error((
-                        "Killing job for not making progress in {} seconds. "
-                        "Set --heartbeat-timeout=-1 to disable this timeout."
-                    ).format(int(self._heartbeat_timeout)))
-                    os.kill(parent_pid, signal.SIGKILL)
-                    return
-
-        def __getattr__(self, name):
-            wrapped_module = super().__getattr__("module")
-            if hasattr(wrapped_module, name):
-                return getattr(wrapped_module, name)
-            return super().__getattr__(name)
-
-        def forward(self, *args, **kwargs):
-            if self._heartbeat is not None:
-                self._heartbeat.set()
-            return super().forward(*args, **kwargs)
-
-    return _DistributedFairseqModel(**init_kwargs)
-
-
-class TPUDistributedDataParallel(nn.Module):
-
-    def __init__(self, module, process_group):
-        super().__init__()
-        self.module = module
-        self.process_group = process_group
-        self.world_size = distributed_utils.get_world_size(self.process_group)
-
-    def forward(self, *inputs, **kwargs):
-        return self.module(*inputs, **kwargs)
-
-    def all_reduce_grads(self):
-        gradients = []
-        for p in self.parameters():
-            if not p.requires_grad:
-                continue
-            if p.grad is None:
-                p.grad = torch.zeros_like(p)
-            if p.grad.requires_grad:
-                raise RuntimeError(
-                    "TPUDistributedDataParallel only works with gradients that don't "
-                    "require grad"
-                )
-            gradients.append(p.grad)
-
-        import torch_xla.core.xla_model as xm
-        xm.all_reduce(
-            'sum',
-            gradients,
-            scale=1. / self.world_size,
-            groups=self.process_group[1],
-        )
+    return wrapped_model
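
A single-process sketch (illustrative only, not part of the diff above) of the object DistributedFairseqModel now returns for --ddp-backend=pytorch_ddp: a plain composition of the new wrappers rather than a dynamically created ddp_class subclass. The gloo group, MASTER_ADDR and MASTER_PORT values below are placeholders, only there so DistributedDataParallel can be constructed on CPU::

    import os

    import torch
    import torch.distributed as dist
    from torch import nn
    from torch.nn.parallel import DistributedDataParallel

    from fairseq.distributed import DistributedTimeoutWrapper, ModuleProxyWrapper

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = nn.Linear(5, 10)
    wrapped = DistributedDataParallel(model)                  # gradient all-reduce
    wrapped = ModuleProxyWrapper(wrapped)                     # getattr/state_dict passthrough
    wrapped = DistributedTimeoutWrapper(wrapped, timeout=-1)  # --heartbeat-timeout watchdog (disabled here)

    out = wrapped(torch.rand(2, 5))
    assert wrapped.state_dict().keys() == model.state_dict().keys()
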
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 274d556ea..b441eaa4d 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -69,7 +69,12 @@ class Trainer(object):
         elif cfg.common.bf16:
             self._criterion = self._criterion.to(dtype=torch.bfloat16)
             self._model = self._model.to(dtype=torch.bfloat16)
-        if not cfg.distributed_training.pipeline_model_parallel:
+        if (
+            not cfg.distributed_training.pipeline_model_parallel
+            # the DistributedFairseqModel wrapper will handle moving to device,
+            # so only handle cases which don't use the wrapper
+            and not self.use_distributed_wrapper
+        ):
             self._criterion = self._criterion.to(device=self.device)
             self._model = self._model.to(device=self.device)
         self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
@@ -158,18 +163,25 @@ class Trainer(object):
         # parallel rank 0
         return self.data_parallel_rank == 0
 
+    @property
+    def use_distributed_wrapper(self) -> bool:
+        return (
+            self.data_parallel_world_size > 1
+            and not self.cfg.optimization.use_bmuf
+        )
+
     @property
     def criterion(self):
         if self._wrapped_criterion is None:
             if (
                 utils.has_parameters(self._criterion)
-                and self.data_parallel_world_size > 1
-                and not self.cfg.optimization.use_bmuf
+                and self.use_distributed_wrapper
             ):
                 self._wrapped_criterion = models.DistributedFairseqModel(
                     self.cfg.distributed_training,
                     self._criterion,
                     process_group=self.data_parallel_process_group,
+                    device=self.device,
                 )
             else:
                 self._wrapped_criterion = self._criterion
@@ -178,11 +190,12 @@
     @property
     def model(self):
         if self._wrapped_model is None:
-            if self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf:
+            if self.use_distributed_wrapper:
                 self._wrapped_model = models.DistributedFairseqModel(
                     self.cfg.distributed_training,
                     self._model,
                     process_group=self.data_parallel_process_group,
+                    device=self.device,
                 )
             else:
                 self._wrapped_model = self._model
@@ -268,8 +281,8 @@ class Trainer(object):
             checkpoint_utils.save_state(
                 filename,
                 self.cfg,
-                self.get_model().state_dict(),
-                self.get_criterion(),
+                self.model.state_dict(),
+                self.criterion,
                 self.optimizer,
                 self.lr_scheduler,
                 self.get_num_updates(),
@@ -336,7 +349,7 @@ class Trainer(object):
 
         # load model parameters
         try:
-            self.get_model().load_state_dict(
+            self.model.load_state_dict(
                 state["model"], strict=True, model_cfg=self.cfg.model
             )
             if utils.has_parameters(self.get_criterion()):
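
The switch from self.get_model().state_dict() to self.model.state_dict() (and the matching load path) relies on the wrappers forwarding those calls to the inner model, so checkpoint keys are the same whether or not use_distributed_wrapper is true. A minimal sketch (illustrative only, not part of the diff above)::

    from torch import nn

    from fairseq.distributed import DistributedTimeoutWrapper

    model = nn.Linear(5, 10)
    # outermost wrapper in the stack the trainer now sees; state_dict and
    # load_state_dict are forwarded straight to the inner model
    wrapped = DistributedTimeoutWrapper(model, timeout=-1)

    assert wrapped.state_dict().keys() == model.state_dict().keys()
    wrapped.load_state_dict(model.state_dict())  # round-trips without key remapping
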
diff --git a/tests/distributed/test_distributed_timeout_wrapper.py b/tests/distributed/test_distributed_timeout_wrapper.py
new file mode 100644
index 000000000..27908b9d3
--- /dev/null
+++ b/tests/distributed/test_distributed_timeout_wrapper.py
@@ -0,0 +1,54 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import signal
+import time
+import unittest
+
+import torch
+from torch import nn
+
+from fairseq.distributed import DistributedTimeoutWrapper
+
+
+class ModuleWithDelay(nn.Module):
+
+    def __init__(self, delay):
+        super().__init__()
+        self.delay = delay
+
+    def forward(self, x):
+        time.sleep(self.delay)
+        return x
+
+
+class TestDistributedTimeoutWrapper(unittest.TestCase):
+
+    def setUp(self):
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
+    def test_no_timeout(self):
+        module = DistributedTimeoutWrapper(ModuleWithDelay(1), 0, signal.SIGINT)
+        module(torch.rand(5))
+        module.stop_timeout()
+
+    def test_timeout_safe(self):
+        module = DistributedTimeoutWrapper(ModuleWithDelay(1), 10, signal.SIGINT)
+        module(torch.rand(5))
+        module.stop_timeout()
+
+    def test_timeout_killed(self):
+        with self.assertRaises(KeyboardInterrupt):
+            module = DistributedTimeoutWrapper(ModuleWithDelay(5), 1, signal.SIGINT)
+            module(torch.rand(5))
+            module.stop_timeout()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/distributed/test_module_proxy_wrapper.py b/tests/distributed/test_module_proxy_wrapper.py
new file mode 100644
index 000000000..2803a044c
--- /dev/null
+++ b/tests/distributed/test_module_proxy_wrapper.py
@@ -0,0 +1,75 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from torch import nn
+
+from fairseq.distributed import ModuleProxyWrapper
+
+from .utils import objects_are_equal
+
+
+class MockDDPWrapper(nn.Module):
+    """A simple wrapper with an interface similar to DistributedDataParallel."""
+
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, x):
+        return self.module(x)
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(5, 10)
+        self.xyz = "hello"
+
+    def forward(self, x):
+        return self.linear(x)
+
+    def get_xyz(self):
+        return self.xyz
+
+
+class TestModuleProxyWrapper(unittest.TestCase):
+
+    def _get_module(self):
+        module = Model()
+        wrapped_module = MockDDPWrapper(module)
+        wrapped_module = ModuleProxyWrapper(wrapped_module)
+        return wrapped_module, module
+
+    def test_getattr_forwarding(self):
+        wrapped_module, module = self._get_module()
+        assert module.xyz == "hello"
+        assert module.get_xyz() == "hello"
+        assert wrapped_module.xyz == "hello"
+
+        wrapped_module.xyz = "world"
+        assert wrapped_module.xyz == "world"
+        assert module.get_xyz() == "hello"
+
+    def test_state_dict(self):
+        wrapped_module, module = self._get_module()
+        assert objects_are_equal(wrapped_module.state_dict(), module.state_dict())
+
+    def test_load_state_dict(self):
+        wrapped_module, module = self._get_module()
+        wrapped_module.load_state_dict(module.state_dict())
+        input = torch.rand(4, 5)
+        torch.testing.assert_allclose(wrapped_module(input), module(input))
+
+    def test_forward(self):
+        wrapped_module, module = self._get_module()
+        input = torch.rand(4, 5)
+        torch.testing.assert_allclose(wrapped_module(input), module(input))
+
+
+if __name__ == "__main__":
+    unittest.main()