Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded) (#1667)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1667

Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded)

This enables full parameter + optimizer state sharding by using
FullyShardedDataParallel (FSDP) from fairscale. The user just needs to pass
`--ddp-backend=fully_sharded` to enable it. Other common options work
out-of-the-box (e.g., `--fp16`, `--memory-efficient-fp16`, `--update-freq`,
etc.), so this should be a drop-in replacement for the "c10d" backend.
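For reference, here is one way a run with the new flag might be launched programmatically. This is a minimal sketch, not part of this commit: the data path, architecture, and other flags are illustrative placeholders; only `--ddp-backend=fully_sharded` is the new piece.

    # Hypothetical launch sketch; data path and model flags are placeholders.
    from fairseq import options
    from fairseq.dataclass.utils import convert_namespace_to_omegaconf
    from fairseq.distributed import utils as distributed_utils
    from fairseq_cli.train import main as train_main

    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, input_args=[
        "data-bin/example",                # hypothetical preprocessed dataset
        "--arch", "transformer",
        "--optimizer", "adam",
        "--max-tokens", "4096",
        "--ddp-backend", "fully_sharded",  # shard parameters + optimizer state
        "--fp16",                          # other common flags work unchanged
        "--update-freq", "2",
    ])
    cfg = convert_namespace_to_omegaconf(args)
    distributed_utils.call_main(cfg, train_main)  # dispatch to the worker processes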

This yields significant speedups for small models and enables training ~13B
parameter models on 8 GPUs and 175B parameter models on 128 GPUs, without model
parallelism.

This also adds a new option `--cpu-offload` that offloads the optimizer state
and FP32 model copy to CPU, which is particularly useful when combined with
`--optimizer=cpu_adam`.
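As an illustration only (dataset/model flags omitted; assumes the optional cpu_adam dependencies are installed), the flag combination would look like:

    # Sketch of the offloading flag combination; not a complete command.
    offload_flags = [
        "--ddp-backend", "fully_sharded",
        "--cpu-offload",              # keep optimizer state + FP32 params on CPU
        "--optimizer", "cpu_adam",    # run the Adam update on CPU as well
        "--fp16",
    ]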

Note: after enabling this, each GPU will save a checkpoint file, since the
optimizer state is sharded. Each checkpoint will contain a single shard of the
optimizer state and the rank 0 checkpoint will contain the full model weights.
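Concretely, with 4 GPUs each save produces 4 files. A small sketch of the naming implied by the new `checkpoint_suffix` property below (the exact file names are an assumption, not a documented API):

    # Illustrative per-rank checkpoint naming; file name pattern is an assumption.
    def shard_checkpoint_name(epoch: int, rank: int, suffix: str = "") -> str:
        # mirrors checkpoint_suffix + "-shard{rank}" from the Trainer change below
        return f"checkpoint{epoch}{suffix}-shard{rank}.pt"

    # e.g. checkpoint1-shard0.pt ... checkpoint1-shard3.pt, where the shard0
    # file also holds the full (unsharded) model weights.
    print([shard_checkpoint_name(1, r) for r in range(4)])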

Note: a known limitation of the current implementation is that you cannot
resume training on a different world_size. This constraint will be relaxed in
future iterations.

Test Plan: Imported from OSS

Reviewed By: sshleifer

Differential Revision: D26771144

Pulled By: myleott

fbshipit-source-id: 74c2f46f57719e24e2dcfc9d9ee7c2fc0aeedb46
Authored by Myle Ott on 2021-03-04 13:31:02 -08:00; committed by Facebook GitHub Bot
parent 6d23cc7e7c
commit 656d7e5779
13 changed files with 292 additions and 31 deletions

View File

@@ -355,7 +355,22 @@ class DistributedTrainingConfig(FairseqDataclass):
     zero_sharding: ZERO_SHARDING_CHOICES = field(
         default="none", metadata={"help": "ZeRO sharding"}
     )
+    fp16: bool = II("common.fp16")
+    memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
     tpu: bool = II("common.tpu")
+    # configuration for --ddp-backend=fully_sharded
+    no_reshard_after_forward: bool = field(
+        default=False,
+        metadata={"help": "don't reshard parameters after forward pass"},
+    )
+    fp32_reduce_scatter: bool = field(
+        default=False,
+        metadata={"help": "reduce-scatter grads in FP32"},
+    )
+    cpu_offload: bool = field(
+        default=False,
+        metadata={"help": "offload FP32 params to CPU"}
+    )
 
 
 @dataclass

View File

@@ -37,6 +37,7 @@ def ChoiceEnum(choices: List[str]):
 LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
 DDP_BACKEND_CHOICES = ChoiceEnum([
     "c10d",  # alias for pytorch_ddp
+    "fully_sharded",  # FullyShardedDataParallel from fairscale
     "legacy_ddp",
     "no_c10d",  # alias for legacy_ddp
     "pytorch_ddp",

View File

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from .distributed_timeout_wrapper import DistributedTimeoutWrapper
+from .fully_sharded_data_parallel import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel
 from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
 from .module_proxy_wrapper import ModuleProxyWrapper
 from .tpu_distributed_data_parallel import TPUDistributedDataParallel
@@ -11,6 +12,9 @@ from .tpu_distributed_data_parallel import TPUDistributedDataParallel
 __all__ = [
     "DistributedTimeoutWrapper",
+    "fsdp_enable_wrap",
+    "fsdp_wrap",
+    "FullyShardedDataParallel",
     "LegacyDistributedDataParallel",
     "ModuleProxyWrapper",
     "TPUDistributedDataParallel",

View File

@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Optional

import torch

from fairseq.dataclass.configs import DistributedTrainingConfig
from fairseq.distributed import utils as dist_utils


try:
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    has_FSDP = True
except ImportError:
    FSDP = torch.nn.Module
    has_FSDP = False


class FullyShardedDataParallel(FSDP):
    """
    A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some
    fairseq-specific checkpoint saving/loading logic.

    Args:
        use_sharded_state (bool): if True, then ``state_dict`` will return
            ``FSDP.local_state_dict`` and ``load_state_dict`` will call
            ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will
            return the full model weights on data parallel rank 0 (empty on
            other ranks) and ``load_state_dict`` will broadcast model weights
            from rank 0 to other ranks.
    """

    def __init__(self, *args, use_sharded_state: bool = False, **kwargs):
        if not has_FSDP:
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )
        super().__init__(*args, **kwargs)
        self.use_sharded_state = use_sharded_state

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        if self.use_sharded_state:
            return super().local_state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        else:
            if self.rank == 0:
                return super().state_dict(
                    destination=destination, prefix=prefix, keep_vars=keep_vars
                )
            else:
                # We must call state_dict() due to use of communication
                # primitives. But we don't use the result.
                super().state_dict()
                return destination or {}

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        if self.use_sharded_state:
            return super().load_local_state_dict(state_dict, strict=strict)
        else:
            state_dict = dist_utils.broadcast_object(
                state_dict, src_rank=0, group=self.process_group
            )
            return super().load_state_dict(state_dict, strict=strict)


@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale"
        )
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        from fairscale.utils.testing import DummyProcessGroup

        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": True,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
    }
    with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config):
        yield


def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
    """
    Helper to wrap layers/modules in FSDP. This falls back to a no-op if
    fairscale is not available.

    Args:
        module (nn.Module): module to (maybe) wrap
        min_num_params (int, Optional): minimum number of layer params to wrap
    """
    try:
        from fairscale.nn import wrap

        cls = FullyShardedDataParallel
        if min_num_params is not None:
            num_params = sum(p.numel() for p in module.parameters())
            if num_params >= min_num_params:
                return wrap(module, cls=cls, **kwargs)
            else:
                return module
        else:
            return wrap(module, cls=cls, **kwargs)
    except ImportError:
        return module
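To illustrate how the two helpers above are meant to be used together, a minimal usage sketch follows (assumptions: fairscale is installed, a single-process DistributedTrainingConfig with fp16 disabled; this is not code from the commit):

    import torch.nn as nn
    from fairseq.dataclass.configs import DistributedTrainingConfig
    from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap

    dist_cfg = DistributedTrainingConfig(
        distributed_world_size=1,   # triggers the DummyProcessGroup fallback above
        fp16=False,
        memory_efficient_fp16=False,
    )
    with fsdp_enable_wrap(dist_cfg):
        # these layers fall below min_num_params, so fsdp_wrap returns them unwrapped
        layers = [fsdp_wrap(nn.Linear(512, 512), min_num_params=1e8) for _ in range(4)]
        # the root module is wrapped unconditionally, flattening and sharding its params
        model = fsdp_wrap(nn.Sequential(*layers))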

View File

@@ -105,12 +105,27 @@ def DistributedFairseqModel(args, model, process_group, device):
         )
         # forward missing getattr and state_dict/load_state_dict to orig model
         wrapped_model = ModuleProxyWrapper(wrapped_model)
+    elif args.ddp_backend == "fully_sharded":
+        try:
+            from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+        except ImportError:
+            raise ImportError(
+                "Cannot find FullyShardedDataParallel. "
+                "Please install fairscale with: pip install fairscale"
+            )
+        assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP"
+        wrapped_model = model
+        if args.memory_efficient_fp16:
+            wrapped_model = wrapped_model.half()
+        if not args.cpu_offload:
+            wrapped_model = wrapped_model.to(device=device)
     else:
         raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)
 
     # kill hung distributed jobs after a timeout
-    wrapped_model = DistributedTimeoutWrapper(
-        wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1)
-    )
+    if getattr(args, "heartbeat_timeout", -1) > 0:
+        wrapped_model = DistributedTimeoutWrapper(
+            wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1)
+        )
 
     return wrapped_model

View File

@@ -27,6 +27,13 @@ from torch import Tensor
 logger = logging.getLogger(__name__)
 
 
+def check_type(module, expected_type):
+    if hasattr(module, "unwrapped_module"):
+        assert isinstance(module.unwrapped_module, expected_type)
+    else:
+        assert isinstance(module, expected_type)
+
+
 class BaseFairseqModel(nn.Module):
     """Base class for fairseq models."""
@@ -284,8 +291,9 @@ class FairseqEncoderDecoderModel(BaseFairseqModel):
         self.encoder = encoder
         self.decoder = decoder
-        assert isinstance(self.encoder, FairseqEncoder)
-        assert isinstance(self.decoder, FairseqDecoder)
+
+        check_type(self.encoder, FairseqEncoder)
+        check_type(self.decoder, FairseqDecoder)
 
     def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
         """
@@ -365,8 +373,8 @@ class FairseqMultiModel(BaseFairseqModel):
         assert encoders.keys() == decoders.keys()
         self.keys = list(encoders.keys())
         for key in self.keys:
-            assert isinstance(encoders[key], FairseqEncoder)
-            assert isinstance(decoders[key], FairseqDecoder)
+            check_type(encoders[key], FairseqEncoder)
+            check_type(decoders[key], FairseqDecoder)
 
         self.models = nn.ModuleDict(
             {
@@ -469,7 +477,7 @@ class FairseqLanguageModel(BaseFairseqModel):
     def __init__(self, decoder):
         super().__init__()
         self.decoder = decoder
-        assert isinstance(self.decoder, FairseqDecoder)
+        check_type(self.decoder, FairseqDecoder)
 
     def forward(self, src_tokens, **kwargs):
         """
@@ -530,7 +538,7 @@ class FairseqEncoderModel(BaseFairseqModel):
     def __init__(self, encoder):
         super().__init__()
         self.encoder = encoder
-        assert isinstance(self.encoder, FairseqEncoder)
+        check_type(self.encoder, FairseqEncoder)
 
     def forward(self, src_tokens, src_lengths, **kwargs):
         """

View File

@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 from fairseq import utils
+from fairseq.distributed import fsdp_wrap
 from fairseq.models import (
     FairseqEncoder,
     FairseqEncoderDecoderModel,
@@ -240,6 +241,9 @@ class TransformerModel(FairseqEncoderDecoderModel):
             args.checkpoint_activations = True  # offloading implies checkpointing
         encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
         decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
+        if not args.share_all_embeddings:
+            encoder = fsdp_wrap(encoder, min_num_params=1e8)
+            decoder = fsdp_wrap(decoder, min_num_params=1e8)
         return cls(args, encoder, decoder)
 
     @classmethod
@@ -386,6 +390,7 @@ class TransformerEncoder(FairseqEncoder):
         if getattr(args, "checkpoint_activations", False):
             offload_to_cpu = getattr(args, "offload_activations", False)
             layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
+        layer = fsdp_wrap(layer, min_num_params=1e8)
         return layer
 
     def forward_embedding(
@@ -726,6 +731,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         if getattr(args, "checkpoint_activations", False):
             offload_to_cpu = getattr(args, "offload_activations", False)
             layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
+        layer = fsdp_wrap(layer, min_num_params=1e8)
         return layer
 
     def forward(
def forward( def forward(

View File

@@ -107,6 +107,10 @@ class CPUAdam(torch.optim.Optimizer):
             self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode
         )
 
+    @property
+    def supports_flat_params(self):
+        return True
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None

View File

@@ -322,6 +322,10 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
     def all_reduce_grads(self, module):
         self.fp32_optimizer.all_reduce_grads(module)
 
+    @property
+    def supports_flat_params(self):
+        return self.fp32_optimizer.supports_flat_params
+
 
 class _MemoryEfficientFP16OptimizerMixin(object):
     def __init__(self, *args, **kwargs):
@@ -442,6 +446,10 @@ class _MemoryEfficientFP16OptimizerMixin(object):
         else:
             self._multiply_factor = 1.0
 
+    @property
+    def supports_flat_params(self):
+        return self.wrapped_optimizer.supports_flat_params
+
 
 class MemoryEfficientFP16Optimizer(
     _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer
@@ -461,8 +469,10 @@ class MemoryEfficientFP16Optimizer(
     *supports_memory_efficient_fp16* property.
     """
 
-    def __init__(self, cfg: DictConfig, params, optimizer, **kwargs):
-        if not optimizer.supports_memory_efficient_fp16:
+    def __init__(
+        self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs
+    ):
+        if not allow_unsupported and not optimizer.supports_memory_efficient_fp16:
             raise ValueError(
                 "Unsupported optimizer: {}".format(optimizer.__class__.__name__)
             )

View File

@@ -63,15 +63,31 @@ class Trainer(object):
         else:
             self.device = torch.device("cpu")
 
+        if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+            if self.cfg.common.bf16:
+                raise ValueError(
+                    "FullyShardedDataParallel is not compatible with --bf16 or "
+                    "--memory-efficient-bf16"
+                )
+            if self.cfg.distributed_training.zero_sharding != "none":
+                raise ValueError(
+                    "FullyShardedDataParallel is not compatible with --zero-sharding "
+                    "option (it's already built in)"
+                )
+        else:
+            if self.cfg.distributed_training.cpu_offload:
+                raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded")
+
         # copy model and criterion to current device/dtype
         self._criterion = criterion
         self._model = model
-        if cfg.common.fp16:
-            self._criterion = self._criterion.half()
-            self._model = self._model.half()
-        elif cfg.common.bf16:
-            self._criterion = self._criterion.to(dtype=torch.bfloat16)
-            self._model = self._model.to(dtype=torch.bfloat16)
+        if cfg.distributed_training.ddp_backend != "fully_sharded":
+            if cfg.common.fp16:
+                self._criterion = self._criterion.half()
+                self._model = self._model.half()
+            elif cfg.common.bf16:
+                self._criterion = self._criterion.to(dtype=torch.bfloat16)
+                self._model = self._model.to(dtype=torch.bfloat16)
         if (
             not cfg.distributed_training.pipeline_model_parallel
             # the DistributedFairseqModel wrapper will handle moving to device,
@@ -171,17 +187,26 @@ class Trainer(object):
         return (
             self.data_parallel_world_size > 1
             and not self.cfg.optimization.use_bmuf
+        ) or (
+            self.cfg.distributed_training.ddp_backend == "fully_sharded"
+            and self.cfg.distributed_training.cpu_offload
         )
 
     @property
     def should_save_checkpoint_on_current_rank(self) -> bool:
         """Indicates whether to save checkpoints on the current DDP rank."""
-        return self.is_data_parallel_master
+        if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+            return True
+        else:
+            return self.is_data_parallel_master
 
     @property
     def checkpoint_suffix(self) -> str:
         """Suffix to add to the checkpoint file name."""
-        return self.cfg.checkpoint.checkpoint_suffix or ""
+        if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+            return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(self.data_parallel_rank)
+        else:
+            return self.cfg.checkpoint.checkpoint_suffix or ""
 
     @property
     def criterion(self):
@@ -234,7 +259,20 @@ class Trainer(object):
                 )
             )
 
-        if self.cfg.common.fp16 or self.cfg.common.bf16:
+        if (
+            self.cfg.distributed_training.ddp_backend == "fully_sharded"
+            and self.cfg.common.fp16
+        ):
+            # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
+            # mostly for the grad scaling. But if we don't have the
+            # --memory-efficient-fp16 flag set, then we're effectively doing
+            # regular --fp16 and can allow the use of optimizers that would
+            # otherwise be unsupported by MemoryEfficientFP16Optimizer.
+            allow_unsupported = not self.cfg.common.memory_efficient_fp16
+            self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
+                self.cfg, params, allow_unsupported=allow_unsupported
+            )
+        elif self.cfg.common.fp16 or self.cfg.common.bf16:
             if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                 logger.info(
                     "NOTE: your device does NOT support faster training with --fp16, "
@@ -254,6 +292,16 @@
             logger.info("NOTE: your device may support faster training with --fp16")
             self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)
 
+        if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+            assert not self.cfg.optimization.use_bmuf, \
+                "--ddp-backend=fully_sharded is not compatible with BMUF"
+            assert self._optimizer.supports_flat_params, (
+                "--ddp-backend=fully_sharded is only compatible with pointwise "
+                "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). "
+                "However, the sharding will result in slightly different results when "
+                "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)"
+            )
+
         if self.cfg.optimization.use_bmuf:
             self._optimizer = optim.FairseqBMUF(
                 self.cfg.bmuf,
@@ -355,6 +403,8 @@ class Trainer(object):
             # TPUs don't support broadcast yet, so load checkpoints
             # on every worker for now
             or self.tpu
+            # FSDP requires loading checkpoint shards on all ranks
+            or self.cfg.distributed_training.ddp_backend == "fully_sharded"
         )
 
         if load_on_all_ranks or self.data_parallel_rank == 0:
@@ -965,7 +1015,21 @@
         metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
 
     def clip_grad_norm(self, clip_norm):
-        return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None)
+        def agg_norm_fn(total_norm):
+            if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+                total_norm = total_norm ** 2
+                if (
+                    self.data_parallel_process_group is not None
+                    or torch.distributed.is_initialized()
+                ):
+                    total_norm = distributed_utils.all_reduce(
+                        total_norm.cuda(), group=self.data_parallel_process_group
+                    )
+                total_norm = total_norm ** 0.5
+            return total_norm
+
+        return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=agg_norm_fn)
 
     def cumulative_training_time(self):
         if self._cumulative_training_time is None:

View File

@@ -18,7 +18,6 @@ import numpy as np
 import torch
 from fairseq import (
     checkpoint_utils,
-    distributed_utils,
     options,
     quantization_utils,
     tasks,
@@ -27,7 +26,7 @@ from fairseq import (
 from fairseq.data import iterators
 from fairseq.dataclass.configs import FairseqConfig
 from fairseq.dataclass.utils import convert_namespace_to_omegaconf
-from fairseq.distributed_utils import is_master
+from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
 from fairseq.file_io import PathManager
 from fairseq.logging import meters, metrics, progress_bar
 from fairseq.model_parallel.megatron_trainer import MegatronTrainer
@@ -50,7 +49,7 @@ def main(cfg: FairseqConfig) -> None:
     utils.import_user_module(cfg.common)
 
-    if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
+    if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
         # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
         logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
@@ -87,7 +86,11 @@ def main(cfg: FairseqConfig) -> None:
     assert cfg.criterion, "Please specify criterion to train a model"
 
     # Build model and criterion
-    model = task.build_model(cfg.model)
+    if cfg.distributed_training.ddp_backend == "fully_sharded":
+        with fsdp_enable_wrap(cfg.distributed_training):
+            model = fsdp_wrap(task.build_model(cfg.model))
+    else:
+        model = task.build_model(cfg.model)
     criterion = task.build_criterion(cfg.criterion)
     logger.info(model)
     logger.info("task: {}".format(task.__class__.__name__))
@@ -95,8 +98,8 @@ def main(cfg: FairseqConfig) -> None:
     logger.info("criterion: {}".format(criterion.__class__.__name__))
     logger.info(
         "num. model params: {:,} (num. trained: {:,})".format(
-            sum(p.numel() for p in model.parameters()),
-            sum(p.numel() for p in model.parameters() if p.requires_grad),
+            sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()),
+            sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad),
         )
     )

View File

@@ -1697,8 +1697,9 @@ class TestActivationCheckpointing(unittest.TestCase):
         """Neither --checkpoint-activations nor --offload-activations should change loss"""
         with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:
-            create_dummy_data(data_dir, num_examples=20)
-            preprocess_translation_data(data_dir)
+            with self.assertLogs():
+                create_dummy_data(data_dir, num_examples=20)
+                preprocess_translation_data(data_dir)
             offload_logs = self._train(data_dir, ["--offload-activations"])
             baseline_logs = self._train(data_dir, [])
@@ -1720,8 +1721,9 @@
         """--checkpoint-activations should not change loss"""
         with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:
-            create_dummy_data(data_dir, num_examples=20)
-            preprocess_translation_data(data_dir)
+            with self.assertLogs():
+                create_dummy_data(data_dir, num_examples=20)
+                preprocess_translation_data(data_dir)
             ckpt_logs = self._train(data_dir, ["--checkpoint-activations"])
             baseline_logs = self._train(data_dir, [])
             assert len(baseline_logs) == len(ckpt_logs)

View File

@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import unittest
 from typing import Sequence
@@ -20,6 +21,12 @@ def sample(id: int, length: int):
 class TestDataset(unittest.TestCase):
+    def setUp(self):
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
     def test_round_robin_zip_datasets(self):
         long_dataset = lang_pair_dataset([10, 9, 8, 11])
         short_dataset = lang_pair_dataset([11, 9])