Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-07-14 18:50:22 +03:00)
Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded) (#1667)
Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1667

Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded)

This enables full parameter and optimizer state sharding by using FullyShardedDataParallel (FSDP) from fairscale. The user just needs to pass `--ddp-backend=fully_sharded` to enable it. Other common options work out of the box (e.g., `--fp16`, `--memory-efficient-fp16`, `--update-freq`, etc.), so this should be a drop-in replacement for the "c10d" backend.

This yields sizable speedups for small models and enables training ~13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs, without model parallelism.

This also adds a new option `--cpu-offload` that offloads the optimizer state and the FP32 model copy to CPU, which is particularly useful when combined with `--optimizer=cpu_adam`.

Note: after enabling this, each GPU will save its own checkpoint file, since the optimizer state is sharded. Each checkpoint contains a single shard of the optimizer state, and the rank-0 checkpoint additionally contains the full model weights.

Note: a known limitation of the current implementation is that you cannot resume training with a different world_size. This constraint will be relaxed in future iterations.

Test Plan: Imported from OSS

Reviewed By: sshleifer

Differential Revision: D26771144

Pulled By: myleott

fbshipit-source-id: 74c2f46f57719e24e2dcfc9d9ee7c2fc0aeedb46
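As an illustration only (this exact flag combination is hypothetical and not taken from the PR's test plan), a run might look like: `fairseq-train <data-bin> --arch <arch> --ddp-backend=fully_sharded --fp16 --optimizer cpu_adam --cpu-offload`.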
Parent: 6d23cc7e7c
Commit: 656d7e5779
@@ -355,7 +355,22 @@ class DistributedTrainingConfig(FairseqDataclass):
    zero_sharding: ZERO_SHARDING_CHOICES = field(
        default="none", metadata={"help": "ZeRO sharding"}
    )
    fp16: bool = II("common.fp16")
    memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
    tpu: bool = II("common.tpu")
+    # configuration for --ddp-backend=fully_sharded
+    no_reshard_after_forward: bool = field(
+        default=False,
+        metadata={"help": "don't reshard parameters after forward pass"},
+    )
+    fp32_reduce_scatter: bool = field(
+        default=False,
+        metadata={"help": "reduce-scatter grads in FP32"},
+    )
+    cpu_offload: bool = field(
+        default=False,
+        metadata={"help": "offload FP32 params to CPU"}
+    )


@dataclass
@@ -37,6 +37,7 @@ def ChoiceEnum(choices: List[str]):
LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum([
    "c10d",  # alias for pytorch_ddp
+    "fully_sharded",  # FullyShardedDataParallel from fairscale
    "legacy_ddp",
    "no_c10d",  # alias for legacy_ddp
    "pytorch_ddp",
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

from .distributed_timeout_wrapper import DistributedTimeoutWrapper
+from .fully_sharded_data_parallel import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
from .module_proxy_wrapper import ModuleProxyWrapper
from .tpu_distributed_data_parallel import TPUDistributedDataParallel
@@ -11,6 +12,9 @@ from .tpu_distributed_data_parallel import TPUDistributedDataParallel

__all__ = [
    "DistributedTimeoutWrapper",
+    "fsdp_enable_wrap",
+    "fsdp_wrap",
+    "FullyShardedDataParallel",
    "LegacyDistributedDataParallel",
    "ModuleProxyWrapper",
    "TPUDistributedDataParallel",
fairseq/distributed/fully_sharded_data_parallel.py (new file, 122 lines)
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Optional

import torch

from fairseq.dataclass.configs import DistributedTrainingConfig
from fairseq.distributed import utils as dist_utils


try:
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    has_FSDP = True
except ImportError:
    FSDP = torch.nn.Module
    has_FSDP = False


class FullyShardedDataParallel(FSDP):
    """
    A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some
    fairseq-specific checkpoint saving/loading logic.

    Args:
        use_sharded_state (bool): if True, then ``state_dict`` will return
            ``FSDP.local_state_dict`` and ``load_state_dict`` will call
            ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will
            return the full model weights on data parallel rank 0 (empty on
            other ranks) and ``load_state_dict`` will broadcast model weights
            from rank 0 to other ranks.
    """

    def __init__(self, *args, use_sharded_state: bool = False, **kwargs):
        if not has_FSDP:
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )
        super().__init__(*args, **kwargs)
        self.use_sharded_state = use_sharded_state

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        if self.use_sharded_state:
            return super().local_state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        else:
            if self.rank == 0:
                return super().state_dict(
                    destination=destination, prefix=prefix, keep_vars=keep_vars
                )
            else:
                # We must call state_dict() due to use of communication
                # primitives. But we don't use the result.
                super().state_dict()
                return destination or {}

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        if self.use_sharded_state:
            return super().load_local_state_dict(state_dict, strict=strict)
        else:
            state_dict = dist_utils.broadcast_object(
                state_dict, src_rank=0, group=self.process_group
            )
            return super().load_state_dict(state_dict, strict=strict)


@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale"
        )
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        from fairscale.utils.testing import DummyProcessGroup
        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": True,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
    }
    with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config):
        yield


def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
    """
    Helper to wrap layers/modules in FSDP. This falls back to a no-op if
    fairscale is not available.

    Args:
        module (nn.Module): module to (maybe) wrap
        min_num_params (int, Optional): minimum number of layer params to wrap
    """
    try:
        from fairscale.nn import wrap
        cls = FullyShardedDataParallel
        if min_num_params is not None:
            num_params = sum(p.numel() for p in module.parameters())
            if num_params >= min_num_params:
                return wrap(module, cls=cls, **kwargs)
            else:
                return module
        else:
            return wrap(module, cls=cls, **kwargs)
    except ImportError:
        return module
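For orientation, here is a minimal sketch (not part of the diff) of how these two helpers are meant to compose, mirroring the trainer entry-point change further below; the `cfg` and `task` objects and the surrounding function are illustrative placeholders:

# Illustrative sketch only: build a model under fsdp_enable_wrap() so that
# fsdp_wrap() calls made during model construction shard the wrapped modules.
from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap


def build_model_for_fsdp(cfg, task):
    # cfg.distributed_training carries the new flags added above
    # (no_reshard_after_forward, fp32_reduce_scatter, cpu_offload, ...).
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        with fsdp_enable_wrap(cfg.distributed_training):
            # Outermost wrap; inner layers may already be wrapped via
            # fsdp_wrap(layer, min_num_params=...) inside the model code.
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        model = task.build_model(cfg.model)
    return model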
@@ -105,12 +105,27 @@ def DistributedFairseqModel(args, model, process_group, device):
        )
        # forward missing getattr and state_dict/load_state_dict to orig model
        wrapped_model = ModuleProxyWrapper(wrapped_model)
+    elif args.ddp_backend == "fully_sharded":
+        try:
+            from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+        except ImportError:
+            raise ImportError(
+                "Cannot find FullyShardedDataParallel. "
+                "Please install fairscale with: pip install fairscale"
+            )
+        assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP"
+        wrapped_model = model
+        if args.memory_efficient_fp16:
+            wrapped_model = wrapped_model.half()
+        if not args.cpu_offload:
+            wrapped_model = wrapped_model.to(device=device)
    else:
        raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)

    # kill hung distributed jobs after a timeout
-    wrapped_model = DistributedTimeoutWrapper(
-        wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1)
-    )
+    if getattr(args, "heartbeat_timeout", -1) > 0:
+        wrapped_model = DistributedTimeoutWrapper(
+            wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1)
+        )

    return wrapped_model
@@ -27,6 +27,13 @@ from torch import Tensor

logger = logging.getLogger(__name__)


+def check_type(module, expected_type):
+    if hasattr(module, "unwrapped_module"):
+        assert isinstance(module.unwrapped_module, expected_type)
+    else:
+        assert isinstance(module, expected_type)
+
+
class BaseFairseqModel(nn.Module):
    """Base class for fairseq models."""

@@ -284,8 +291,9 @@ class FairseqEncoderDecoderModel(BaseFairseqModel):

        self.encoder = encoder
        self.decoder = decoder
-        assert isinstance(self.encoder, FairseqEncoder)
-        assert isinstance(self.decoder, FairseqDecoder)
+
+        check_type(self.encoder, FairseqEncoder)
+        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
@@ -365,8 +373,8 @@ class FairseqMultiModel(BaseFairseqModel):
        assert encoders.keys() == decoders.keys()
        self.keys = list(encoders.keys())
        for key in self.keys:
-            assert isinstance(encoders[key], FairseqEncoder)
-            assert isinstance(decoders[key], FairseqDecoder)
+            check_type(encoders[key], FairseqEncoder)
+            check_type(decoders[key], FairseqDecoder)

        self.models = nn.ModuleDict(
            {
@@ -469,7 +477,7 @@ class FairseqLanguageModel(BaseFairseqModel):
    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
-        assert isinstance(self.decoder, FairseqDecoder)
+        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, **kwargs):
        """
@@ -530,7 +538,7 @@ class FairseqEncoderModel(BaseFairseqModel):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
-        assert isinstance(self.encoder, FairseqEncoder)
+        check_type(self.encoder, FairseqEncoder)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from fairseq import utils
+from fairseq.distributed import fsdp_wrap
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
@@ -240,6 +241,9 @@ class TransformerModel(FairseqEncoderDecoderModel):
            args.checkpoint_activations = True  # offloading implies checkpointing
        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
+        if not args.share_all_embeddings:
+            encoder = fsdp_wrap(encoder, min_num_params=1e8)
+            decoder = fsdp_wrap(decoder, min_num_params=1e8)
        return cls(args, encoder, decoder)

    @classmethod
@@ -386,6 +390,7 @@ class TransformerEncoder(FairseqEncoder):
        if getattr(args, "checkpoint_activations", False):
            offload_to_cpu = getattr(args, "offload_activations", False)
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
+        layer = fsdp_wrap(layer, min_num_params=1e8)
        return layer

    def forward_embedding(
@@ -726,6 +731,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
        if getattr(args, "checkpoint_activations", False):
            offload_to_cpu = getattr(args, "offload_activations", False)
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
+        layer = fsdp_wrap(layer, min_num_params=1e8)
        return layer

    def forward(
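A note on the threshold used above: `min_num_params=1e8` means an encoder, decoder, or individual layer only gets its own FSDP instance when it holds at least 100M parameters; per the `fsdp_wrap` helper defined earlier, smaller modules are returned unwrapped and are simply sharded as part of the enclosing FSDP instance's flattened parameters.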
@@ -107,6 +107,10 @@ class CPUAdam(torch.optim.Optimizer):
            self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode
        )

+    @property
+    def supports_flat_params(self):
+        return True
+
    @torch.no_grad()
    def step(self, closure=None):
        loss = None
@@ -322,6 +322,10 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
    def all_reduce_grads(self, module):
        self.fp32_optimizer.all_reduce_grads(module)

+    @property
+    def supports_flat_params(self):
+        return self.fp32_optimizer.supports_flat_params
+

class _MemoryEfficientFP16OptimizerMixin(object):
    def __init__(self, *args, **kwargs):
@@ -442,6 +446,10 @@ class _MemoryEfficientFP16OptimizerMixin(object):
        else:
            self._multiply_factor = 1.0

+    @property
+    def supports_flat_params(self):
+        return self.wrapped_optimizer.supports_flat_params
+

class MemoryEfficientFP16Optimizer(
    _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer
@@ -461,8 +469,10 @@ class MemoryEfficientFP16Optimizer(
    *supports_memory_efficient_fp16* property.
    """

-    def __init__(self, cfg: DictConfig, params, optimizer, **kwargs):
-        if not optimizer.supports_memory_efficient_fp16:
+    def __init__(
+        self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs
+    ):
+        if not allow_unsupported and not optimizer.supports_memory_efficient_fp16:
            raise ValueError(
                "Unsupported optimizer: {}".format(optimizer.__class__.__name__)
            )
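Why `supports_flat_params` appears here: with `--ddp-backend=fully_sharded`, fairscale flattens each FSDP instance's parameters into large 1-D buffers, so the optimizer steps over a few big flat tensors rather than per-module parameters. The Trainer change below asserts `self._optimizer.supports_flat_params`, which is why the property is threaded through CPUAdam and both FP16 wrappers; as that assertion message notes, only pointwise optimizers behave identically under this flattening and sharding.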
@@ -63,15 +63,31 @@ class Trainer(object):
        else:
            self.device = torch.device("cpu")

+        if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+            if self.cfg.common.bf16:
+                raise ValueError(
+                    "FullyShardedDataParallel is not compatible with --bf16 or "
+                    "--memory-efficient-bf16"
+                )
+            if self.cfg.distributed_training.zero_sharding != "none":
+                raise ValueError(
+                    "FullyShardedDataParallel is not compatible with --zero-sharding "
+                    "option (it's already built in)"
+                )
+        else:
+            if self.cfg.distributed_training.cpu_offload:
+                raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded")
+
        # copy model and criterion to current device/dtype
        self._criterion = criterion
        self._model = model
-        if cfg.common.fp16:
-            self._criterion = self._criterion.half()
-            self._model = self._model.half()
-        elif cfg.common.bf16:
-            self._criterion = self._criterion.to(dtype=torch.bfloat16)
-            self._model = self._model.to(dtype=torch.bfloat16)
+        if cfg.distributed_training.ddp_backend != "fully_sharded":
+            if cfg.common.fp16:
+                self._criterion = self._criterion.half()
+                self._model = self._model.half()
+            elif cfg.common.bf16:
+                self._criterion = self._criterion.to(dtype=torch.bfloat16)
+                self._model = self._model.to(dtype=torch.bfloat16)
        if (
            not cfg.distributed_training.pipeline_model_parallel
            # the DistributedFairseqModel wrapper will handle moving to device,
@@ -171,17 +187,26 @@ class Trainer(object):
        return (
            self.data_parallel_world_size > 1
            and not self.cfg.optimization.use_bmuf
+        ) or (
+            self.cfg.distributed_training.ddp_backend == "fully_sharded"
+            and self.cfg.distributed_training.cpu_offload
        )

    @property
    def should_save_checkpoint_on_current_rank(self) -> bool:
        """Indicates whether to save checkpoints on the current DDP rank."""
-        return self.is_data_parallel_master
+        if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+            return True
+        else:
+            return self.is_data_parallel_master

    @property
    def checkpoint_suffix(self) -> str:
        """Suffix to add to the checkpoint file name."""
-        return self.cfg.checkpoint.checkpoint_suffix or ""
+        if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+            return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(self.data_parallel_rank)
+        else:
+            return self.cfg.checkpoint.checkpoint_suffix or ""

    @property
    def criterion(self):
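To make the `checkpoint_suffix` change concrete: under `--ddp-backend=fully_sharded` every rank saves a checkpoint, and the suffix distinguishes the shards. A rough illustration, which assumes fairseq's usual `checkpoint_last{suffix}.pt` file-name template (the template itself is not shown in this diff):

# Illustrative only: expected checkpoint names for an 8-GPU fully_sharded run.
# Only the rank-0 file also holds the full (consolidated) model weights.
for rank in range(8):
    suffix = "-shard{0}".format(rank)
    print("checkpoint_last{}.pt".format(suffix))
# checkpoint_last-shard0.pt, checkpoint_last-shard1.pt, ..., checkpoint_last-shard7.pt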
@@ -234,7 +259,20 @@ class Trainer(object):
            )
        )

-        if self.cfg.common.fp16 or self.cfg.common.bf16:
+        if (
+            self.cfg.distributed_training.ddp_backend == "fully_sharded"
+            and self.cfg.common.fp16
+        ):
+            # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
+            # mostly for the grad scaling. But if we don't have the
+            # --memory-efficient-fp16 flag set, then we're effectively doing
+            # regular --fp16 and can allow the use of optimizers that would
+            # otherwise be unsupported by MemoryEfficientFP16Optimizer.
+            allow_unsupported = not self.cfg.common.memory_efficient_fp16
+            self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
+                self.cfg, params, allow_unsupported=allow_unsupported
+            )
+        elif self.cfg.common.fp16 or self.cfg.common.bf16:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                logger.info(
                    "NOTE: your device does NOT support faster training with --fp16, "
@@ -254,6 +292,16 @@ class Trainer(object):
                logger.info("NOTE: your device may support faster training with --fp16")
            self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)

+        if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+            assert not self.cfg.optimization.use_bmuf, \
+                "--ddp-backend=fully_sharded is not compatible with BMUF"
+            assert self._optimizer.supports_flat_params, (
+                "--ddp-backend=fully_sharded is only compatible with pointwise "
+                "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). "
+                "However, the sharding will result in slightly different results when "
+                "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)"
+            )
+
        if self.cfg.optimization.use_bmuf:
            self._optimizer = optim.FairseqBMUF(
                self.cfg.bmuf,
@@ -355,6 +403,8 @@ class Trainer(object):
            # TPUs don't support broadcast yet, so load checkpoints
            # on every worker for now
            or self.tpu
+            # FSDP requires loading checkpoint shards on all ranks
+            or self.cfg.distributed_training.ddp_backend == "fully_sharded"
        )

        if load_on_all_ranks or self.data_parallel_rank == 0:
@@ -965,7 +1015,21 @@ class Trainer(object):
        metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)

    def clip_grad_norm(self, clip_norm):
-        return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None)
+        def agg_norm_fn(total_norm):
+            if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+                total_norm = total_norm ** 2
+                if (
+                    self.data_parallel_process_group is not None
+                    or torch.distributed.is_initialized()
+                ):
+                    total_norm = distributed_utils.all_reduce(
+                        total_norm.cuda(), group=self.data_parallel_process_group
+                    )
+                total_norm = total_norm ** 0.5
+            return total_norm
+
+        return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=agg_norm_fn)

    def cumulative_training_time(self):
        if self._cumulative_training_time is None:
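The aggregation above follows from the fact that under FSDP each rank only holds a shard of the gradients: the local L2 norms are squared, summed across ranks with an all-reduce, and the square root of the sum is the global norm. A toy single-process check of that identity (not fairseq code, just the arithmetic):

import math

# Each "rank" holds a disjoint shard of the gradients, so local L2 norms
# combine by summing their squares.
shard_grads = [[3.0, 4.0], [12.0]]                    # pretend two ranks
local_norms = [math.sqrt(sum(g * g for g in shard)) for shard in shard_grads]
global_norm = math.sqrt(sum(n * n for n in local_norms))  # square, "all-reduce", sqrt
full_norm = math.sqrt(sum(g * g for shard in shard_grads for g in shard))
assert abs(global_norm - full_norm) < 1e-12           # both equal 13.0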
@@ -18,7 +18,6 @@ import numpy as np
import torch
from fairseq import (
    checkpoint_utils,
-    distributed_utils,
    options,
    quantization_utils,
    tasks,
@@ -27,7 +26,7 @@ from fairseq import (
from fairseq.data import iterators
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
-from fairseq.distributed_utils import is_master
+from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
@@ -50,7 +49,7 @@ def main(cfg: FairseqConfig) -> None:

    utils.import_user_module(cfg.common)

-    if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
+    if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

@@ -87,7 +86,11 @@ def main(cfg: FairseqConfig) -> None:
    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
-    model = task.build_model(cfg.model)
+    if cfg.distributed_training.ddp_backend == "fully_sharded":
+        with fsdp_enable_wrap(cfg.distributed_training):
+            model = fsdp_wrap(task.build_model(cfg.model))
+    else:
+        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
@@ -95,8 +98,8 @@ def main(cfg: FairseqConfig) -> None:
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info(
        "num. model params: {:,} (num. trained: {:,})".format(
-            sum(p.numel() for p in model.parameters()),
-            sum(p.numel() for p in model.parameters() if p.requires_grad),
+            sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()),
+            sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad),
        )
    )
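Why the parameter count switches to `getattr(p, "_orig_size", p).numel()`: after FSDP flattens and shards parameters, `p.numel()` only reflects the local shard, so the old sum would undercount. The `_orig_size` attribute (set by fairscale's parameter flattening, as this diff relies on) records the original shape. A minimal illustration with a faked attribute, just to show the counting logic:

import torch

p = torch.zeros(10)                    # pretend this is a local shard of a flat param
p._orig_size = torch.Size([4, 100])    # original (unsharded) parameter shape, faked here
print(getattr(p, "_orig_size", p).numel())   # 400, not 10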
@@ -1697,8 +1697,9 @@ class TestActivationCheckpointing(unittest.TestCase):
        """Neither --checkpoint-activations nor --offload-activations should change loss"""
        with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:

            create_dummy_data(data_dir, num_examples=20)
            preprocess_translation_data(data_dir)
            with self.assertLogs():
                create_dummy_data(data_dir, num_examples=20)
                preprocess_translation_data(data_dir)
                offload_logs = self._train(data_dir, ["--offload-activations"])
                baseline_logs = self._train(data_dir, [])

@@ -1720,8 +1721,9 @@ class TestActivationCheckpointing(unittest.TestCase):
        """--checkpoint-activations should not change loss"""

        with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:
            create_dummy_data(data_dir, num_examples=20)
            preprocess_translation_data(data_dir)
            with self.assertLogs():
                create_dummy_data(data_dir, num_examples=20)
                preprocess_translation_data(data_dir)
                ckpt_logs = self._train(data_dir, ["--checkpoint-activations"])
                baseline_logs = self._train(data_dir, [])
            assert len(baseline_logs) == len(ckpt_logs)
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

+import logging
import unittest
from typing import Sequence

@@ -20,6 +21,12 @@ def sample(id: int, length: int):


class TestDataset(unittest.TestCase):
+    def setUp(self):
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
    def test_round_robin_zip_datasets(self):
        long_dataset = lang_pair_dataset([10, 9, 8, 11])
        short_dataset = lang_pair_dataset([11, 9])