Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-10-26 17:32:57 +03:00)
Refactor distributed code under fairseq.distributed (#1546)
Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1546

Test Plan: Imported from OSS

Reviewed By: girifb

Differential Revision: D25836853

Pulled By: myleott

fbshipit-source-id: c5076615d49774633ecfaf0aa68b68e8b2331bd9
This commit is contained in:
parent 922528d58f
commit d68a3530dd
fairseq/__init__.py
@@ -28,6 +28,7 @@ from fairseq.dataclass.initialize import hydra_init
 hydra_init()

 import fairseq.criterions  # noqa
+import fairseq.distributed  # noqa
 import fairseq.models  # noqa
 import fairseq.modules  # noqa
 import fairseq.optim  # noqa
fairseq/distributed/__init__.py (new file, 17 lines)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .distributed_timeout_wrapper import DistributedTimeoutWrapper
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
from .module_proxy_wrapper import ModuleProxyWrapper
from .tpu_distributed_data_parallel import TPUDistributedDataParallel


__all__ = [
    "DistributedTimeoutWrapper",
    "LegacyDistributedDataParallel",
    "ModuleProxyWrapper",
    "TPUDistributedDataParallel",
]
fairseq/distributed/distributed_timeout_wrapper.py (new file, 94 lines)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import signal
import threading

from torch import nn


logger = logging.getLogger(__name__)


class DistributedTimeoutWrapper(nn.Module):
    """
    A wrapper that kills the process if no progress is made within a given
    *timeout*. The timer is reset every time :func:`forward` is called.

    Usage::

        module = DistributedTimeoutWrapper(module, timeout=30)
        x = module(input)
        time.sleep(20)  # safe
        x = module(input)
        time.sleep(45)  # job will be killed before this returns

    Args:
        module (nn.Module): module to wrap
        timeout (int): number of seconds before killing the process
            (set to a value <= 0 to disable the timeout)
        signal (Optional): signal to send once timeout is triggered
    """
    def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGKILL):
        super().__init__()
        self.module = module
        self.timeout = timeout
        self.signal = signal

        if timeout > 0:
            self._heartbeat = threading.Event()
            self._heartbeat_thread = threading.Thread(
                target=self._check_heartbeat,
                args=(os.getpid(),),
                daemon=True,
            )
            self._heartbeat_thread.start()
            self._terminated = False
        else:
            self._heartbeat = None
            self._heartbeat_thread = None

    def __del__(self):
        self.stop_timeout()

    def __getattr__(self, name):
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)

    def stop_timeout(self):
        if self._heartbeat_thread is not None:
            self._terminated = True
            self._heartbeat_thread.join()

    def state_dict(self, *args, **kwargs):
        return self.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        return self.module.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        if self._heartbeat is not None:
            self._heartbeat.set()
        return self.module(*args, **kwargs)

    def _check_heartbeat(self, parent_pid):
        self._heartbeat.wait()  # wait for the first forward pass
        while True:
            self._heartbeat.clear()
            success = self._heartbeat.wait(timeout=self.timeout)
            if self._terminated:
                break
            elif not success:
                logger.error((
                    "Killing job for not making progress in {} seconds. "
                    "Set --heartbeat-timeout=-1 to disable this timeout."
                ).format(int(self.timeout)))
                os.kill(parent_pid, self.signal)
                return
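
The wrapper is self-contained, so its behaviour can be exercised outside of distributed training. A minimal sketch (assuming fairseq is installed; the SlowModule toy below is hypothetical and mirrors the ModuleWithDelay helper in the tests added later in this commit)::

    import signal
    import time

    import torch
    from torch import nn

    from fairseq.distributed import DistributedTimeoutWrapper


    class SlowModule(nn.Module):
        # hypothetical toy module that sleeps before returning its input
        def __init__(self, delay):
            super().__init__()
            self.delay = delay

        def forward(self, x):
            time.sleep(self.delay)
            return x


    # use SIGINT instead of the default SIGKILL so a timeout surfaces as KeyboardInterrupt
    module = DistributedTimeoutWrapper(SlowModule(delay=1), timeout=3, signal=signal.SIGINT)
    module(torch.rand(5))  # each forward call resets the heartbeat timer
    module.stop_timeout()  # join the watchdog thread before exiting

With timeout=3 and a one-second forward pass the call is safe; a module that stalls for longer than the timeout would receive the configured signal instead.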
fairseq/distributed/legacy_distributed_data_parallel.py
@@ -14,15 +14,13 @@ This version also supports the *no_sync* context manager, which allows faster
 training with `--update-freq`.
 """

-import copy
 from collections import OrderedDict
 from contextlib import contextmanager

 import torch
 from torch import nn
-from torch.autograd import Variable

-from . import distributed_utils
+from fairseq import distributed_utils


 class LegacyDistributedDataParallel(nn.Module):
@@ -64,13 +62,6 @@ class LegacyDistributedDataParallel(nn.Module):
                 paramlists[device] += [param]
         self.per_device_params = list(paramlists.values())

-    def __getstate__(self):
-        attrs = copy.copy(self.__dict__)
-        return attrs
-
-    def __setstate__(self, state):
-        super().__setstate__(state)
-
     @contextmanager
     def no_sync(self):
         """A context manager to disable gradient synchronization."""
fairseq/distributed/module_proxy_wrapper.py (new file, 55 lines)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn


class ModuleProxyWrapper(nn.Module):
    """
    Wrap a DistributedDataParallel module and forward requests for missing
    attributes to the module wrapped by DDP (the twice-wrapped module).
    Also forward calls to :func:`state_dict` and :func:`load_state_dict`.

    Usage::

        module.xyz = "hello world"
        wrapped_module = DistributedDataParallel(module, **ddp_args)
        wrapped_module = ModuleProxyWrapper(wrapped_module)
        assert wrapped_module.xyz == "hello world"
        assert wrapped_module.state_dict().keys() == module.state_dict().keys()

    Args:
        module (nn.Module): module to wrap
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        assert hasattr(module, "module"), \
            "ModuleProxyWrapper expects input to wrap another module"
        self.module = module

    def __getattr__(self, name):
        """Forward missing attributes to twice-wrapped module."""
        try:
            # defer to nn.Module's logic
            return super().__getattr__(name)
        except AttributeError:
            try:
                # forward to the once-wrapped module
                return getattr(self.module, name)
            except AttributeError:
                # forward to the twice-wrapped module
                return getattr(self.module.module, name)

    def state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module."""
        return self.module.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module."""
        return self.module.module.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
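
A short sketch of the forwarding behaviour (single process, no real DDP; the MockWrapper below is a hypothetical stand-in that only mimics DistributedDataParallel's habit of storing the real network under .module)::

    import torch
    from torch import nn

    from fairseq.distributed import ModuleProxyWrapper


    class MockWrapper(nn.Module):
        # hypothetical stand-in for DistributedDataParallel
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, x):
            return self.module(x)


    model = nn.Linear(5, 10)
    model.xyz = "hello world"

    wrapped = ModuleProxyWrapper(MockWrapper(model))
    assert wrapped.xyz == "hello world"  # attribute request falls through to the inner model
    assert wrapped.state_dict().keys() == model.state_dict().keys()  # no "module." prefix
    out = wrapped(torch.rand(2, 5))  # forward still goes through the outer wrapper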
fairseq/distributed/tpu_distributed_data_parallel.py (new file, 43 lines)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from fairseq import distributed_utils


class TPUDistributedDataParallel(nn.Module):

    def __init__(self, module, process_group):
        super().__init__()
        self.module = module
        self.process_group = process_group
        self.world_size = distributed_utils.get_world_size(self.process_group)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        gradients = []
        for p in self.parameters():
            if not p.requires_grad:
                continue
            if p.grad is None:
                p.grad = torch.zeros_like(p)
            if p.grad.requires_grad:
                raise RuntimeError(
                    "TPUDistributedDataParallel only works with gradients that don't "
                    "require grad"
                )
            gradients.append(p.grad)

        import torch_xla.core.xla_model as xm
        xm.all_reduce(
            'sum',
            gradients,
            scale=1. / self.world_size,
            groups=self.process_group[1],
        )
fairseq/models/distributed_fairseq_model.py
@@ -3,7 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import inspect
 import logging
 import os
 import signal
@@ -11,9 +10,15 @@ import threading

 import torch
 import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel

 from fairseq import distributed_utils
-from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
+from fairseq.distributed import (
+    DistributedTimeoutWrapper,
+    LegacyDistributedDataParallel,
+    ModuleProxyWrapper,
+    TPUDistributedDataParallel,
+)


 logger = logging.getLogger(__name__)
@@ -26,7 +31,7 @@ except ImportError:
     _GOSSIP_DISABLED = True


-def DistributedFairseqModel(args, model, process_group):
+def DistributedFairseqModel(args, model, process_group, device):
     """
     Wrap a *model* to support distributed data parallel training.

@@ -40,42 +45,42 @@ def DistributedFairseqModel(args, model, process_group):
         model (BaseFairseqModel): model to wrap
         process_group: the c10d process group to be used for distributed data
             parallel all-reduction.
+        device: device to move model to
     """
-    # determine which DDP class to extend
     assert isinstance(model, nn.Module)
     if args.tpu:
-        ddp_class = TPUDistributedDataParallel
-        init_kwargs = dict(
-            module=model,
+        wrapped_model = TPUDistributedDataParallel(
+            module=model.to(device),
             process_group=process_group,
         )
+        # forward missing getattr and state_dict/load_state_dict to orig model
+        wrapped_model = ModuleProxyWrapper(wrapped_model)
     elif args.ddp_backend in {"c10d", "pytorch_ddp"}:
-        ddp_class = nn.parallel.DistributedDataParallel
-        init_kwargs = dict(
-            module=model,
+        wrapped_model = DistributedDataParallel(
+            module=model.to(device),
             device_ids=[args.device_id],
             output_device=args.device_id,
             broadcast_buffers=args.broadcast_buffers,
             bucket_cap_mb=args.bucket_cap_mb,
             process_group=process_group,
+            find_unused_parameters=args.find_unused_parameters,
         )
-        # Maintain backward compatibility
-        if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]:
-            init_kwargs["find_unused_parameters"] = args.find_unused_parameters
+        # forward missing getattr and state_dict/load_state_dict to orig model
+        wrapped_model = ModuleProxyWrapper(wrapped_model)
     elif args.ddp_backend in {"no_c10d", "legacy_ddp"}:
-        ddp_class = LegacyDistributedDataParallel
-        init_kwargs = dict(
-            module=model,
+        wrapped_model = LegacyDistributedDataParallel(
+            module=model.to(device),
             buffer_size=2 ** 28,
             process_group=process_group,
        )
+        # forward missing getattr and state_dict/load_state_dict to orig model
+        wrapped_model = ModuleProxyWrapper(wrapped_model)
     elif args.ddp_backend == "slow_mo":
         if _GOSSIP_DISABLED:
             raise ImportError(
                 "Cannot find gossip library. Please install from: "
                 "github.com/facebookresearch/stochastic_gradient_push"
             )
-        ddp_class = gossip.GossipDataParallel

         # The values of slowmo_momentum below were obtained by tuning on the
         # En-De 16 dataset by training the transformer_wmt_en_de_large model
@@ -89,8 +94,8 @@ def DistributedFairseqModel(args, model, process_group):
         else:
             args.slowmo_momentum = 0.6

-        init_kwargs = dict(
-            module=model,
+        wrapped_model = gossip.GossipDataParallel(
+            module=model.to(device),
             device_ids=[args.device_id],
             output_device=args.device_id,
             broadcast_buffers=args.broadcast_buffers,
@@ -99,88 +104,14 @@ def DistributedFairseqModel(args, model, process_group):
             localsgd=(args.slowmo_algorithm == "LocalSGD"),
             localsgd_frequency=args.localsgd_frequency,
         )
+        # forward missing getattr and state_dict/load_state_dict to orig model
+        wrapped_model = ModuleProxyWrapper(wrapped_model)
     else:
         raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)

-    heartbeat_timeout = getattr(args, "heartbeat_timeout", -1)
+    # kill hung distributed jobs after a timeout
+    wrapped_model = DistributedTimeoutWrapper(
+        wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1)
+    )

-    class _DistributedFairseqModel(ddp_class):
-        """
-        Extend DistributedDataParallel to check for missing attributes in the
-        wrapped module and to add a timeout to kill the job if no progress is
-        made (--heartbeat-timeout).
-        """
-
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self._heartbeat_timeout = heartbeat_timeout
-            if self._heartbeat_timeout > 0:
-                self._heartbeat = threading.Event()
-                self._heartbeat_thread = threading.Thread(
-                    target=self._check_heartbeat,
-                    args=(os.getpid(),),
-                    daemon=True,
-                )
-                self._heartbeat_thread.start()
-            else:
-                self._heartbeat = None
-
-        def _check_heartbeat(self, parent_pid):
-            self._heartbeat.wait()  # wait for the first forward pass
-            while True:
-                self._heartbeat.clear()
-                success = self._heartbeat.wait(timeout=self._heartbeat_timeout)
-                if not success:
-                    logger.error((
-                        "Killing job for not making progress in {} seconds. "
-                        "Set --heartbeat-timeout=-1 to disable this timeout."
-                    ).format(int(self._heartbeat_timeout)))
-                    os.kill(parent_pid, signal.SIGKILL)
-                    return
-
-        def __getattr__(self, name):
-            wrapped_module = super().__getattr__("module")
-            if hasattr(wrapped_module, name):
-                return getattr(wrapped_module, name)
-            return super().__getattr__(name)
-
-        def forward(self, *args, **kwargs):
-            if self._heartbeat is not None:
-                self._heartbeat.set()
-            return super().forward(*args, **kwargs)
-
-    return _DistributedFairseqModel(**init_kwargs)
-
-
-class TPUDistributedDataParallel(nn.Module):
-
-    def __init__(self, module, process_group):
-        super().__init__()
-        self.module = module
-        self.process_group = process_group
-        self.world_size = distributed_utils.get_world_size(self.process_group)
-
-    def forward(self, *inputs, **kwargs):
-        return self.module(*inputs, **kwargs)
-
-    def all_reduce_grads(self):
-        gradients = []
-        for p in self.parameters():
-            if not p.requires_grad:
-                continue
-            if p.grad is None:
-                p.grad = torch.zeros_like(p)
-            if p.grad.requires_grad:
-                raise RuntimeError(
-                    "TPUDistributedDataParallel only works with gradients that don't "
-                    "require grad"
-                )
-            gradients.append(p.grad)
-
-        import torch_xla.core.xla_model as xm
-        xm.all_reduce(
-            'sum',
-            gradients,
-            scale=1. / self.world_size,
-            groups=self.process_group[1],
-        )
+    return wrapped_model
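
The net effect of these hunks is that DistributedFairseqModel now returns a plain composition of wrappers instead of a dynamically defined subclass: the backend-specific data-parallel wrapper, then ModuleProxyWrapper, then DistributedTimeoutWrapper. A rough single-process sketch of the resulting structure (the IdentityDDP stand-in below is hypothetical, since a real DistributedDataParallel needs an initialized process group)::

    import torch
    from torch import nn

    from fairseq.distributed import DistributedTimeoutWrapper, ModuleProxyWrapper


    class IdentityDDP(nn.Module):
        # hypothetical stand-in for the backend-specific DDP wrapper
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)


    model = nn.Linear(8, 8)

    # same wrapping order as the refactored DistributedFairseqModel
    wrapped = IdentityDDP(model)                              # 1. data-parallel wrapper
    wrapped = ModuleProxyWrapper(wrapped)                     # 2. forward getattr/state_dict to the original model
    wrapped = DistributedTimeoutWrapper(wrapped, timeout=-1)  # 3. heartbeat watchdog (disabled when timeout <= 0)

    out = wrapped(torch.rand(4, 8))
    assert wrapped.state_dict().keys() == model.state_dict().keys()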
fairseq/trainer.py
@@ -69,7 +69,12 @@ class Trainer(object):
         elif cfg.common.bf16:
             self._criterion = self._criterion.to(dtype=torch.bfloat16)
             self._model = self._model.to(dtype=torch.bfloat16)
-        if not cfg.distributed_training.pipeline_model_parallel:
+        if (
+            not cfg.distributed_training.pipeline_model_parallel
+            # the DistributedFairseqModel wrapper will handle moving to device,
+            # so only handle cases which don't use the wrapper
+            and not self.use_distributed_wrapper
+        ):
             self._criterion = self._criterion.to(device=self.device)
             self._model = self._model.to(device=self.device)
         self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
@@ -158,18 +163,25 @@ class Trainer(object):
         # parallel rank 0
         return self.data_parallel_rank == 0

+    @property
+    def use_distributed_wrapper(self) -> bool:
+        return (
+            self.data_parallel_world_size > 1
+            and not self.cfg.optimization.use_bmuf
+        )
+
     @property
     def criterion(self):
         if self._wrapped_criterion is None:
             if (
                 utils.has_parameters(self._criterion)
-                and self.data_parallel_world_size > 1
-                and not self.cfg.optimization.use_bmuf
+                and self.use_distributed_wrapper
             ):
                 self._wrapped_criterion = models.DistributedFairseqModel(
                     self.cfg.distributed_training,
                     self._criterion,
                     process_group=self.data_parallel_process_group,
+                    device=self.device,
                 )
             else:
                 self._wrapped_criterion = self._criterion
@@ -178,11 +190,12 @@ class Trainer(object):
     @property
     def model(self):
         if self._wrapped_model is None:
-            if self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf:
+            if self.use_distributed_wrapper:
                 self._wrapped_model = models.DistributedFairseqModel(
                     self.cfg.distributed_training,
                     self._model,
                     process_group=self.data_parallel_process_group,
+                    device=self.device,
                 )
             else:
                 self._wrapped_model = self._model
@@ -268,8 +281,8 @@ class Trainer(object):
             checkpoint_utils.save_state(
                 filename,
                 self.cfg,
-                self.get_model().state_dict(),
-                self.get_criterion(),
+                self.model.state_dict(),
+                self.criterion,
                 self.optimizer,
                 self.lr_scheduler,
                 self.get_num_updates(),
@@ -336,7 +349,7 @@ class Trainer(object):

             # load model parameters
             try:
-                self.get_model().load_state_dict(
+                self.model.load_state_dict(
                     state["model"], strict=True, model_cfg=self.cfg.model
                 )
                 if utils.has_parameters(self.get_criterion()):
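
Both the criterion and model properties above now defer to the same use_distributed_wrapper predicate, and the eager .to(device) in __init__ is skipped whenever the wrapper will move the model itself. A standalone restatement of that predicate (illustrative only, not the Trainer code)::

    def should_use_distributed_wrapper(data_parallel_world_size: int, use_bmuf: bool) -> bool:
        # mirrors Trainer.use_distributed_wrapper from the diff above: wrap with
        # DistributedFairseqModel only when there is more than one data-parallel
        # worker and BMUF is not in use
        return data_parallel_world_size > 1 and not use_bmuf

    assert should_use_distributed_wrapper(8, use_bmuf=False)
    assert not should_use_distributed_wrapper(1, use_bmuf=False)
    assert not should_use_distributed_wrapper(8, use_bmuf=True)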
tests/distributed/test_distributed_timeout_wrapper.py (new file, 54 lines)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import signal
import time
import unittest

import torch
from torch import nn

from fairseq.distributed import DistributedTimeoutWrapper


class ModuleWithDelay(nn.Module):

    def __init__(self, delay):
        super().__init__()
        self.delay = delay

    def forward(self, x):
        time.sleep(self.delay)
        return x


class TestDistributedTimeoutWrapper(unittest.TestCase):

    def setUp(self):
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        logging.disable(logging.NOTSET)

    def test_no_timeout(self):
        module = DistributedTimeoutWrapper(ModuleWithDelay(1), 0, signal.SIGINT)
        module(torch.rand(5))
        module.stop_timeout()

    def test_timeout_safe(self):
        module = DistributedTimeoutWrapper(ModuleWithDelay(1), 10, signal.SIGINT)
        module(torch.rand(5))
        module.stop_timeout()

    def test_timeout_killed(self):
        with self.assertRaises(KeyboardInterrupt):
            module = DistributedTimeoutWrapper(ModuleWithDelay(5), 1, signal.SIGINT)
            module(torch.rand(5))
            module.stop_timeout()


if __name__ == "__main__":
    unittest.main()
tests/distributed/test_module_proxy_wrapper.py (new file, 75 lines)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch import nn

from fairseq.distributed import ModuleProxyWrapper

from .utils import objects_are_equal


class MockDDPWrapper(nn.Module):
    """A simple wrapper with an interface similar to DistributedDataParallel."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return self.module(x)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 10)
        self.xyz = "hello"

    def forward(self, x):
        return self.linear(x)

    def get_xyz(self):
        return self.xyz


class TestModuleProxyWrapper(unittest.TestCase):

    def _get_module(self):
        module = Model()
        wrapped_module = MockDDPWrapper(module)
        wrapped_module = ModuleProxyWrapper(wrapped_module)
        return wrapped_module, module

    def test_getattr_forwarding(self):
        wrapped_module, module = self._get_module()
        assert module.xyz == "hello"
        assert module.get_xyz() == "hello"
        assert wrapped_module.xyz == "hello"

        wrapped_module.xyz = "world"
        assert wrapped_module.xyz == "world"
        assert module.get_xyz() == "hello"

    def test_state_dict(self):
        wrapped_module, module = self._get_module()
        assert objects_are_equal(wrapped_module.state_dict(), module.state_dict())

    def test_load_state_dict(self):
        wrapped_module, module = self._get_module()
        wrapped_module.load_state_dict(module.state_dict())
        input = torch.rand(4, 5)
        torch.testing.assert_allclose(wrapped_module(input), module(input))

    def test_forward(self):
        wrapped_module, module = self._get_module()
        input = torch.rand(4, 5)
        torch.testing.assert_allclose(wrapped_module(input), module(input))


if __name__ == "__main__":
    unittest.main()