Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded) (#1667)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1667

Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded)

This enables full parameter + optimizer state sharding by using
FullyShardedDataParallel (FSDP) from fairscale. The user just needs to pass
`--ddp-backend=fully_sharded` to enable it. Other common options work
out-of-the-box (e.g., `--fp16`, `--memory-efficient-fp16`, `--update-freq`,
etc.). This should be a drop-in replacement for the "c10d" backend.
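
For example, a hypothetical invocation (the data path, task, architecture, and
optimization flags below are placeholders; the only new flag required by this
diff is `--ddp-backend=fully_sharded`):

    fairseq-train /path/to/data \
        --task language_modeling --arch transformer_lm \
        --optimizer adam --lr 0.0005 --max-tokens 2048 \
        --fp16 --update-freq 1 \
        --ddp-backend=fully_sharded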

This yields significant speedups for smaller models and enables training
~13B-parameter models on 8 GPUs and 175B-parameter models on 128 GPUs, without
model parallelism.

This also adds a new option `--cpu-offload` that offloads the optimizer state
and FP32 model copy to CPU, which is particularly useful when combined with
`--optimizer=cpu_adam`.
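
For instance (again hypothetical, extending the command above; `...` stands in
for the remaining training flags, and cpu_adam has its own extra dependency):

    fairseq-train /path/to/data ... \
        --ddp-backend=fully_sharded --fp16 \
        --cpu-offload --optimizer=cpu_adam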

Note: once this is enabled, each GPU saves its own checkpoint file, since the
optimizer state is sharded. Each checkpoint contains a single shard of the
optimizer state, and the rank-0 checkpoint additionally contains the full model
weights.
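
For example, with 2 data-parallel workers and the `-shard{rank}` checkpoint
suffix added in the trainer below, the saved files would look roughly like
(hypothetical names, assuming the default `checkpoint_last.pt` naming):

    checkpoint_last-shard0.pt   # full model weights + rank-0 optimizer shard
    checkpoint_last-shard1.pt   # rank-1 optimizer shard (no full model weights)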

Note: a known limitation of the current implementation is that you cannot
resume training with a different world_size; this constraint will be relaxed
in future iterations.

Test Plan: Imported from OSS

Reviewed By: sshleifer

Differential Revision: D26771144

Pulled By: myleott

fbshipit-source-id: 74c2f46f57719e24e2dcfc9d9ee7c2fc0aeedb46
Myle Ott 2021-03-04 13:31:02 -08:00 committed by Facebook GitHub Bot
parent 6d23cc7e7c
commit 656d7e5779
13 changed files with 292 additions and 31 deletions

View File

@ -355,7 +355,22 @@ class DistributedTrainingConfig(FairseqDataclass):
zero_sharding: ZERO_SHARDING_CHOICES = field(
default="none", metadata={"help": "ZeRO sharding"}
)
fp16: bool = II("common.fp16")
memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
tpu: bool = II("common.tpu")
# configuration for --ddp-backend=fully_sharded
no_reshard_after_forward: bool = field(
default=False,
metadata={"help": "don't reshard parameters after forward pass"},
)
fp32_reduce_scatter: bool = field(
default=False,
metadata={"help": "reduce-scatter grads in FP32"},
)
cpu_offload: bool = field(
default=False,
metadata={"help": "offload FP32 params to CPU"}
)
@dataclass

View File

@ -37,6 +37,7 @@ def ChoiceEnum(choices: List[str]):
LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum([
"c10d", # alias for pytorch_ddp
"fully_sharded", # FullyShardedDataParallel from fairscale
"legacy_ddp",
"no_c10d", # alias for legacy_ddp
"pytorch_ddp",

View File

@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from .distributed_timeout_wrapper import DistributedTimeoutWrapper
from .fully_sharded_data_parallel import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
from .module_proxy_wrapper import ModuleProxyWrapper
from .tpu_distributed_data_parallel import TPUDistributedDataParallel
@ -11,6 +12,9 @@ from .tpu_distributed_data_parallel import TPUDistributedDataParallel
__all__ = [
"DistributedTimeoutWrapper",
"fsdp_enable_wrap",
"fsdp_wrap",
"FullyShardedDataParallel",
"LegacyDistributedDataParallel",
"ModuleProxyWrapper",
"TPUDistributedDataParallel",

View File

@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
from typing import Optional

import torch

from fairseq.dataclass.configs import DistributedTrainingConfig
from fairseq.distributed import utils as dist_utils


try:
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    has_FSDP = True
except ImportError:
    FSDP = torch.nn.Module
    has_FSDP = False


class FullyShardedDataParallel(FSDP):
    """
    A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some
    fairseq-specific checkpoint saving/loading logic.

    Args:
        use_sharded_state (bool): if True, then ``state_dict`` will return
            ``FSDP.local_state_dict`` and ``load_state_dict`` will call
            ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will
            return the full model weights on data parallel rank 0 (empty on
            other ranks) and ``load_state_dict`` will broadcast model weights
            from rank 0 to other ranks.
    """

    def __init__(self, *args, use_sharded_state: bool = False, **kwargs):
        if not has_FSDP:
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )
        super().__init__(*args, **kwargs)
        self.use_sharded_state = use_sharded_state

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        if self.use_sharded_state:
            return super().local_state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        else:
            if self.rank == 0:
                return super().state_dict(
                    destination=destination, prefix=prefix, keep_vars=keep_vars
                )
            else:
                # We must call state_dict() due to use of communication
                # primitives. But we don't use the result.
                super().state_dict()
                return destination or {}

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        if self.use_sharded_state:
            return super().load_local_state_dict(state_dict, strict=strict)
        else:
            state_dict = dist_utils.broadcast_object(
                state_dict, src_rank=0, group=self.process_group
            )
            return super().load_state_dict(state_dict, strict=strict)


@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale"
        )
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        from fairscale.utils.testing import DummyProcessGroup

        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": True,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
    }
    with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config):
        yield


def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
    """
    Helper to wrap layers/modules in FSDP. This falls back to a no-op if
    fairscale is not available.

    Args:
        module (nn.Module): module to (maybe) wrap
        min_num_params (int, Optional): minimum number of layer params to wrap
    """
    try:
        from fairscale.nn import wrap

        cls = FullyShardedDataParallel
        if min_num_params is not None:
            num_params = sum(p.numel() for p in module.parameters())
            if num_params >= min_num_params:
                return wrap(module, cls=cls, **kwargs)
            else:
                return module
        else:
            return wrap(module, cls=cls, **kwargs)
    except ImportError:
        return module
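
Note: a minimal usage sketch of the two helpers above, mirroring the
training-entry-point change later in this diff (the `build_model_for_ddp` name
is hypothetical, and `cfg`/`task` are assumed to be a populated FairseqConfig
and task instance):

    from fairseq.dataclass.configs import FairseqConfig
    from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap

    def build_model_for_ddp(cfg: FairseqConfig, task):
        # With --ddp-backend=fully_sharded the model must be built inside
        # fsdp_enable_wrap() so that fsdp_wrap() calls inside build_model()
        # (e.g., the per-layer wrapping in TransformerEncoder/Decoder) take
        # effect; the outermost module is then wrapped explicitly.
        if cfg.distributed_training.ddp_backend == "fully_sharded":
            with fsdp_enable_wrap(cfg.distributed_training):
                return fsdp_wrap(task.build_model(cfg.model))
        return task.build_model(cfg.model)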

View File

@ -105,12 +105,27 @@ def DistributedFairseqModel(args, model, process_group, device):
)
# forward missing getattr and state_dict/load_state_dict to orig model
wrapped_model = ModuleProxyWrapper(wrapped_model)
elif args.ddp_backend == "fully_sharded":
try:
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
except ImportError:
raise ImportError(
"Cannot find FullyShardedDataParallel. "
"Please install fairscale with: pip install fairscale"
)
assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP"
wrapped_model = model
if args.memory_efficient_fp16:
wrapped_model = wrapped_model.half()
if not args.cpu_offload:
wrapped_model = wrapped_model.to(device=device)
else:
raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)
# kill hung distributed jobs after a timeout
wrapped_model = DistributedTimeoutWrapper(
wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1)
)
if getattr(args, "heartbeat_timeout", -1) > 0:
wrapped_model = DistributedTimeoutWrapper(
wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1)
)
return wrapped_model

View File

@ -27,6 +27,13 @@ from torch import Tensor
logger = logging.getLogger(__name__)
def check_type(module, expected_type):
if hasattr(module, "unwrapped_module"):
assert isinstance(module.unwrapped_module, expected_type)
else:
assert isinstance(module, expected_type)
class BaseFairseqModel(nn.Module):
"""Base class for fairseq models."""
@ -284,8 +291,9 @@ class FairseqEncoderDecoderModel(BaseFairseqModel):
self.encoder = encoder
self.decoder = decoder
assert isinstance(self.encoder, FairseqEncoder)
assert isinstance(self.decoder, FairseqDecoder)
check_type(self.encoder, FairseqEncoder)
check_type(self.decoder, FairseqDecoder)
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
"""
@ -365,8 +373,8 @@ class FairseqMultiModel(BaseFairseqModel):
assert encoders.keys() == decoders.keys()
self.keys = list(encoders.keys())
for key in self.keys:
assert isinstance(encoders[key], FairseqEncoder)
assert isinstance(decoders[key], FairseqDecoder)
check_type(encoders[key], FairseqEncoder)
check_type(decoders[key], FairseqDecoder)
self.models = nn.ModuleDict(
{
@ -469,7 +477,7 @@ class FairseqLanguageModel(BaseFairseqModel):
def __init__(self, decoder):
super().__init__()
self.decoder = decoder
assert isinstance(self.decoder, FairseqDecoder)
check_type(self.decoder, FairseqDecoder)
def forward(self, src_tokens, **kwargs):
"""
@ -530,7 +538,7 @@ class FairseqEncoderModel(BaseFairseqModel):
def __init__(self, encoder):
super().__init__()
self.encoder = encoder
assert isinstance(self.encoder, FairseqEncoder)
check_type(self.encoder, FairseqEncoder)
def forward(self, src_tokens, src_lengths, **kwargs):
"""

View File

@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
@ -240,6 +241,9 @@ class TransformerModel(FairseqEncoderDecoderModel):
args.checkpoint_activations = True # offloading implies checkpointing
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
if not args.share_all_embeddings:
encoder = fsdp_wrap(encoder, min_num_params=1e8)
decoder = fsdp_wrap(decoder, min_num_params=1e8)
return cls(args, encoder, decoder)
@classmethod
@ -386,6 +390,7 @@ class TransformerEncoder(FairseqEncoder):
if getattr(args, "checkpoint_activations", False):
offload_to_cpu = getattr(args, "offload_activations", False)
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
layer = fsdp_wrap(layer, min_num_params=1e8)
return layer
def forward_embedding(
@ -726,6 +731,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if getattr(args, "checkpoint_activations", False):
offload_to_cpu = getattr(args, "offload_activations", False)
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
layer = fsdp_wrap(layer, min_num_params=1e8)
return layer
def forward(

View File

@ -107,6 +107,10 @@ class CPUAdam(torch.optim.Optimizer):
self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode
)
@property
def supports_flat_params(self):
return True
@torch.no_grad()
def step(self, closure=None):
loss = None

View File

@ -322,6 +322,10 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
def all_reduce_grads(self, module):
self.fp32_optimizer.all_reduce_grads(module)
@property
def supports_flat_params(self):
return self.fp32_optimizer.supports_flat_params
class _MemoryEfficientFP16OptimizerMixin(object):
def __init__(self, *args, **kwargs):
@ -442,6 +446,10 @@ class _MemoryEfficientFP16OptimizerMixin(object):
else:
self._multiply_factor = 1.0
@property
def supports_flat_params(self):
return self.wrapped_optimizer.supports_flat_params
class MemoryEfficientFP16Optimizer(
_MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer
@ -461,8 +469,10 @@ class MemoryEfficientFP16Optimizer(
*supports_memory_efficient_fp16* property.
"""
def __init__(self, cfg: DictConfig, params, optimizer, **kwargs):
if not optimizer.supports_memory_efficient_fp16:
def __init__(
self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs
):
if not allow_unsupported and not optimizer.supports_memory_efficient_fp16:
raise ValueError(
"Unsupported optimizer: {}".format(optimizer.__class__.__name__)
)

View File

@ -63,15 +63,31 @@ class Trainer(object):
else:
self.device = torch.device("cpu")
if self.cfg.distributed_training.ddp_backend == "fully_sharded":
if self.cfg.common.bf16:
raise ValueError(
"FullyShardedDataParallel is not compatible with --bf16 or "
"--memory-efficient-bf16"
)
if self.cfg.distributed_training.zero_sharding != "none":
raise ValueError(
"FullyShardedDataParallel is not compatible with --zero-sharding "
"option (it's already built in)"
)
else:
if self.cfg.distributed_training.cpu_offload:
raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded")
# copy model and criterion to current device/dtype
self._criterion = criterion
self._model = model
if cfg.common.fp16:
self._criterion = self._criterion.half()
self._model = self._model.half()
elif cfg.common.bf16:
self._criterion = self._criterion.to(dtype=torch.bfloat16)
self._model = self._model.to(dtype=torch.bfloat16)
if cfg.distributed_training.ddp_backend != "fully_sharded":
if cfg.common.fp16:
self._criterion = self._criterion.half()
self._model = self._model.half()
elif cfg.common.bf16:
self._criterion = self._criterion.to(dtype=torch.bfloat16)
self._model = self._model.to(dtype=torch.bfloat16)
if (
not cfg.distributed_training.pipeline_model_parallel
# the DistributedFairseqModel wrapper will handle moving to device,
@ -171,17 +187,26 @@ class Trainer(object):
return (
self.data_parallel_world_size > 1
and not self.cfg.optimization.use_bmuf
) or (
self.cfg.distributed_training.ddp_backend == "fully_sharded"
and self.cfg.distributed_training.cpu_offload
)
@property
def should_save_checkpoint_on_current_rank(self) -> bool:
"""Indicates whether to save checkpoints on the current DDP rank."""
return self.is_data_parallel_master
if self.cfg.distributed_training.ddp_backend == "fully_sharded":
return True
else:
return self.is_data_parallel_master
@property
def checkpoint_suffix(self) -> str:
"""Suffix to add to the checkpoint file name."""
return self.cfg.checkpoint.checkpoint_suffix or ""
if self.cfg.distributed_training.ddp_backend == "fully_sharded":
return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(self.data_parallel_rank)
else:
return self.cfg.checkpoint.checkpoint_suffix or ""
@property
def criterion(self):
@ -234,7 +259,20 @@ class Trainer(object):
)
)
if self.cfg.common.fp16 or self.cfg.common.bf16:
if (
self.cfg.distributed_training.ddp_backend == "fully_sharded"
and self.cfg.common.fp16
):
# FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
# mostly for the grad scaling. But if we don't have the
# --memory-efficient-fp16 flag set, then we're effectively doing
# regular --fp16 and can allow the use of optimizers that would
# otherwise be unsupported by MemoryEfficientFP16Optimizer.
allow_unsupported = not self.cfg.common.memory_efficient_fp16
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
self.cfg, params, allow_unsupported=allow_unsupported
)
elif self.cfg.common.fp16 or self.cfg.common.bf16:
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
logger.info(
"NOTE: your device does NOT support faster training with --fp16, "
@ -254,6 +292,16 @@ class Trainer(object):
logger.info("NOTE: your device may support faster training with --fp16")
self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)
if self.cfg.distributed_training.ddp_backend == "fully_sharded":
assert not self.cfg.optimization.use_bmuf, \
"--ddp-backend=fully_sharded is not compatible with BMUF"
assert self._optimizer.supports_flat_params, (
"--ddp-backend=fully_sharded is only compatible with pointwise "
"optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). "
"However, the sharding will result in slightly different results when "
"using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)"
)
if self.cfg.optimization.use_bmuf:
self._optimizer = optim.FairseqBMUF(
self.cfg.bmuf,
@ -355,6 +403,8 @@ class Trainer(object):
# TPUs don't support broadcast yet, so load checkpoints
# on every worker for now
or self.tpu
# FSDP requires loading checkpoint shards on all ranks
or self.cfg.distributed_training.ddp_backend == "fully_sharded"
)
if load_on_all_ranks or self.data_parallel_rank == 0:
@ -965,7 +1015,21 @@ class Trainer(object):
metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
def clip_grad_norm(self, clip_norm):
return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None)
def agg_norm_fn(total_norm):
if self.cfg.distributed_training.ddp_backend == "fully_sharded":
total_norm = total_norm ** 2
if (
self.data_parallel_process_group is not None
or torch.distributed.is_initialized()
):
total_norm = distributed_utils.all_reduce(
total_norm.cuda(), group=self.data_parallel_process_group
)
total_norm = total_norm ** 0.5
return total_norm
return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=agg_norm_fn)
def cumulative_training_time(self):
if self._cumulative_training_time is None:
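
Note on the agg_norm_fn added to clip_grad_norm above: under fully_sharded each
rank only holds a shard of the gradients, so the local norms must be combined
as the square root of the all-reduced sum of squares. A tiny standalone
illustration with made-up numbers (a plain sum stands in for the all_reduce):

    import torch

    # pretend ranks 0 and 1 computed these local (per-shard) gradient norms
    local_norms = [torch.tensor(3.0), torch.tensor(4.0)]

    # square, "all-reduce" (sum across ranks), then take the square root
    global_norm = sum(n ** 2 for n in local_norms) ** 0.5
    assert torch.isclose(global_norm, torch.tensor(5.0))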

View File

@ -18,7 +18,6 @@ import numpy as np
import torch
from fairseq import (
checkpoint_utils,
distributed_utils,
options,
quantization_utils,
tasks,
@ -27,7 +26,7 @@ from fairseq import (
from fairseq.data import iterators
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.distributed_utils import is_master
from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
@ -50,7 +49,7 @@ def main(cfg: FairseqConfig) -> None:
utils.import_user_module(cfg.common)
if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
# make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
@ -87,7 +86,11 @@ def main(cfg: FairseqConfig) -> None:
assert cfg.criterion, "Please specify criterion to train a model"
# Build model and criterion
model = task.build_model(cfg.model)
if cfg.distributed_training.ddp_backend == "fully_sharded":
with fsdp_enable_wrap(cfg.distributed_training):
model = fsdp_wrap(task.build_model(cfg.model))
else:
model = task.build_model(cfg.model)
criterion = task.build_criterion(cfg.criterion)
logger.info(model)
logger.info("task: {}".format(task.__class__.__name__))
@ -95,8 +98,8 @@ def main(cfg: FairseqConfig) -> None:
logger.info("criterion: {}".format(criterion.__class__.__name__))
logger.info(
"num. model params: {:,} (num. trained: {:,})".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()),
sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad),
)
)

View File

@ -1697,8 +1697,9 @@ class TestActivationCheckpointing(unittest.TestCase):
"""Neither ----checkpoint-activations nor --offload-activations should change loss"""
with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:
create_dummy_data(data_dir, num_examples=20)
preprocess_translation_data(data_dir)
with self.assertLogs():
create_dummy_data(data_dir, num_examples=20)
preprocess_translation_data(data_dir)
offload_logs = self._train(data_dir, ["--offload-activations"])
baseline_logs = self._train(data_dir, [])
@ -1720,8 +1721,9 @@ class TestActivationCheckpointing(unittest.TestCase):
"""--checkpoint-activations should not change loss"""
with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:
create_dummy_data(data_dir, num_examples=20)
preprocess_translation_data(data_dir)
with self.assertLogs():
create_dummy_data(data_dir, num_examples=20)
preprocess_translation_data(data_dir)
ckpt_logs = self._train(data_dir, ["--checkpoint-activations"])
baseline_logs = self._train(data_dir, [])
assert len(baseline_logs) == len(ckpt_logs)

View File

@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import unittest
from typing import Sequence
@ -20,6 +21,12 @@ def sample(id: int, length: int):
class TestDataset(unittest.TestCase):
def setUp(self):
logging.disable(logging.CRITICAL)
def tearDown(self):
logging.disable(logging.NOTSET)
def test_round_robin_zip_datasets(self):
long_dataset = lang_pair_dataset([10, 9, 8, 11])
short_dataset = lang_pair_dataset([11, 9])