From 656d7e5779a9ec4ccf0ad45d86a4ce589c597588 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 4 Mar 2021 13:31:02 -0800 Subject: [PATCH] Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded) (#1667) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1667 Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded) This enables fully parameter + optimizer state sharding by using FullyShardedDataParallel (FSDP) from fairscale. The user just needs to provide `--ddp-backend=fully_sharded` to enable. Other common options work out-of-the-box (e.g., `--fp16`, `--memory-efficient-fp16`, `--update-freq`, etc.). This should be a drop-in replacement for the "c10d" backend. This yields pretty big speedups for small models and enables training ~13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs, without model parallelism. This also adds a new option `--cpu-offload` that offloads the optimizer state and FP32 model copy to CPU, which is particularly useful when combined with `--optimizer=cpu_adam`. Note: after enabling this, each GPU will save a checkpoint file, since the optimizer state is sharded. Each checkpoint will contain a single shard of the optimizer state and the rank 0 checkpoint will contain the full model weights. Note: a known limitation of the current implementation is that you cannot resume training on a different world_size. This constraint will be relaxed in future iterations. Test Plan: Imported from OSS Reviewed By: sshleifer Differential Revision: D26771144 Pulled By: myleott fbshipit-source-id: 74c2f46f57719e24e2dcfc9d9ee7c2fc0aeedb46 --- fairseq/dataclass/configs.py | 15 +++ fairseq/dataclass/constants.py | 1 + fairseq/distributed/__init__.py | 4 + .../fully_sharded_data_parallel.py | 122 ++++++++++++++++++ fairseq/models/distributed_fairseq_model.py | 21 ++- fairseq/models/fairseq_model.py | 20 ++- fairseq/models/transformer.py | 6 + fairseq/optim/cpu_adam.py | 4 + fairseq/optim/fp16_optimizer.py | 14 +- fairseq/trainer.py | 84 ++++++++++-- fairseq_cli/train.py | 15 ++- tests/test_binaries.py | 10 +- tests/test_dataset.py | 7 + 13 files changed, 292 insertions(+), 31 deletions(-) create mode 100644 fairseq/distributed/fully_sharded_data_parallel.py diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 4d3c60bf..5d6aee15 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -355,7 +355,22 @@ class DistributedTrainingConfig(FairseqDataclass): zero_sharding: ZERO_SHARDING_CHOICES = field( default="none", metadata={"help": "ZeRO sharding"} ) + fp16: bool = II("common.fp16") + memory_efficient_fp16: bool = II("common.memory_efficient_fp16") tpu: bool = II("common.tpu") + # configuration for --ddp-backend=fully_sharded + no_reshard_after_forward: bool = field( + default=False, + metadata={"help": "don't reshard parameters after forward pass"}, + ) + fp32_reduce_scatter: bool = field( + default=False, + metadata={"help": "reduce-scatter grads in FP32"}, + ) + cpu_offload: bool = field( + default=False, + metadata={"help": "offload FP32 params to CPU"} + ) @dataclass diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 93bc6d03..faba0862 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -37,6 +37,7 @@ def ChoiceEnum(choices: List[str]): LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) DDP_BACKEND_CHOICES = ChoiceEnum([ "c10d", # alias for pytorch_ddp + "fully_sharded", # 
FullyShardedDataParallel from fairscale "legacy_ddp", "no_c10d", # alias for legacy_ddp "pytorch_ddp", diff --git a/fairseq/distributed/__init__.py b/fairseq/distributed/__init__.py index 7f4016e3..d0b96b73 100644 --- a/fairseq/distributed/__init__.py +++ b/fairseq/distributed/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from .distributed_timeout_wrapper import DistributedTimeoutWrapper +from .fully_sharded_data_parallel import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel from .legacy_distributed_data_parallel import LegacyDistributedDataParallel from .module_proxy_wrapper import ModuleProxyWrapper from .tpu_distributed_data_parallel import TPUDistributedDataParallel @@ -11,6 +12,9 @@ from .tpu_distributed_data_parallel import TPUDistributedDataParallel __all__ = [ "DistributedTimeoutWrapper", + "fsdp_enable_wrap", + "fsdp_wrap", + "FullyShardedDataParallel", "LegacyDistributedDataParallel", "ModuleProxyWrapper", "TPUDistributedDataParallel", diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py new file mode 100644 index 00000000..9d743983 --- /dev/null +++ b/fairseq/distributed/fully_sharded_data_parallel.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from typing import Optional + +import torch + +from fairseq.dataclass.configs import DistributedTrainingConfig +from fairseq.distributed import utils as dist_utils + + +try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + has_FSDP = True +except ImportError: + FSDP = torch.nn.Module + has_FSDP = False + + +class FullyShardedDataParallel(FSDP): + """ + A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some + fairseq-specific checkpoint saving/loading logic. + + Args: + use_sharded_state (bool): if True, then ``state_dict`` will return + ``FSDP.local_state_dict`` and ``load_state_dict`` will call + ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will + return the full model weights on data parallel rank 0 (empty on + other ranks) and ``load_state_dict`` will broadcast model weights + from rank 0 to other ranks. + """ + + def __init__(self, *args, use_sharded_state: bool = False, **kwargs): + if not has_FSDP: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + super().__init__(*args, **kwargs) + self.use_sharded_state = use_sharded_state + + def state_dict(self, destination=None, prefix='', keep_vars=False): + if self.use_sharded_state: + return super().local_state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + if self.rank == 0: + return super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + # We must call state_dict() due to use of communication + # primitives. But we don't use the result. 
+ super().state_dict() + return destination or {} + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + if self.use_sharded_state: + return super().load_local_state_dict(state_dict, strict=strict) + else: + state_dict = dist_utils.broadcast_object( + state_dict, src_rank=0, group=self.process_group + ) + return super().load_state_dict(state_dict, strict=strict) + + +@contextlib.contextmanager +def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False): + try: + from fairscale.nn import enable_wrap + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + if cfg.memory_efficient_fp16: + assert cfg.fp16 # memory_efficient_fp16 should imply fp16 + group = dist_utils.get_data_parallel_group() + if group is None and cfg.distributed_world_size == 1: + from fairscale.utils.testing import DummyProcessGroup + group = DummyProcessGroup(rank=0, size=1) + fsdp_config = { + "process_group": group, + "reshard_after_forward": not cfg.no_reshard_after_forward, + "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16, + "fp32_reduce_scatter": cfg.fp32_reduce_scatter, + "flatten_parameters": True, + "cpu_offload": cfg.cpu_offload, + "compute_dtype": torch.float16 if cfg.fp16 else torch.float32, + "bucket_cap_mb": cfg.bucket_cap_mb, + } + with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config): + yield + + +def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): + """ + Helper to wrap layers/modules in FSDP. This falls back to a no-op if + fairscale is not available. + + Args: + module (nn.Module): module to (maybe) wrap + min_num_params (int, Optional): minimum number of layer params to wrap + """ + try: + from fairscale.nn import wrap + cls = FullyShardedDataParallel + if min_num_params is not None: + num_params = sum(p.numel() for p in module.parameters()) + if num_params >= min_num_params: + return wrap(module, cls=cls, **kwargs) + else: + return module + else: + return wrap(module, cls=cls, **kwargs) + except ImportError: + return module diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index ca157f06..3422faea 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -105,12 +105,27 @@ def DistributedFairseqModel(args, model, process_group, device): ) # forward missing getattr and state_dict/load_state_dict to orig model wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend == "fully_sharded": + try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. 
" + "Please install fairscale with: pip install fairscale" + ) + assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP" + wrapped_model = model + if args.memory_efficient_fp16: + wrapped_model = wrapped_model.half() + if not args.cpu_offload: + wrapped_model = wrapped_model.to(device=device) else: raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) # kill hung distributed jobs after a timeout - wrapped_model = DistributedTimeoutWrapper( - wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1) - ) + if getattr(args, "heartbeat_timeout", -1) > 0: + wrapped_model = DistributedTimeoutWrapper( + wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1) + ) return wrapped_model diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 186f3d24..d393c02a 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -27,6 +27,13 @@ from torch import Tensor logger = logging.getLogger(__name__) +def check_type(module, expected_type): + if hasattr(module, "unwrapped_module"): + assert isinstance(module.unwrapped_module, expected_type) + else: + assert isinstance(module, expected_type) + + class BaseFairseqModel(nn.Module): """Base class for fairseq models.""" @@ -284,8 +291,9 @@ class FairseqEncoderDecoderModel(BaseFairseqModel): self.encoder = encoder self.decoder = decoder - assert isinstance(self.encoder, FairseqEncoder) - assert isinstance(self.decoder, FairseqDecoder) + + check_type(self.encoder, FairseqEncoder) + check_type(self.decoder, FairseqDecoder) def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): """ @@ -365,8 +373,8 @@ class FairseqMultiModel(BaseFairseqModel): assert encoders.keys() == decoders.keys() self.keys = list(encoders.keys()) for key in self.keys: - assert isinstance(encoders[key], FairseqEncoder) - assert isinstance(decoders[key], FairseqDecoder) + check_type(encoders[key], FairseqEncoder) + check_type(decoders[key], FairseqDecoder) self.models = nn.ModuleDict( { @@ -469,7 +477,7 @@ class FairseqLanguageModel(BaseFairseqModel): def __init__(self, decoder): super().__init__() self.decoder = decoder - assert isinstance(self.decoder, FairseqDecoder) + check_type(self.decoder, FairseqDecoder) def forward(self, src_tokens, **kwargs): """ @@ -530,7 +538,7 @@ class FairseqEncoderModel(BaseFairseqModel): def __init__(self, encoder): super().__init__() self.encoder = encoder - assert isinstance(self.encoder, FairseqEncoder) + check_type(self.encoder, FairseqEncoder) def forward(self, src_tokens, src_lengths, **kwargs): """ diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index f2f36baf..a0a0b8dc 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn from fairseq import utils +from fairseq.distributed import fsdp_wrap from fairseq.models import ( FairseqEncoder, FairseqEncoderDecoderModel, @@ -240,6 +241,9 @@ class TransformerModel(FairseqEncoderDecoderModel): args.checkpoint_activations = True # offloading implies checkpointing encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + if not args.share_all_embeddings: + encoder = fsdp_wrap(encoder, min_num_params=1e8) + decoder = fsdp_wrap(decoder, min_num_params=1e8) return cls(args, encoder, decoder) @classmethod @@ -386,6 +390,7 @@ class TransformerEncoder(FairseqEncoder): if getattr(args, 
"checkpoint_activations", False): offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + layer = fsdp_wrap(layer, min_num_params=1e8) return layer def forward_embedding( @@ -726,6 +731,7 @@ class TransformerDecoder(FairseqIncrementalDecoder): if getattr(args, "checkpoint_activations", False): offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + layer = fsdp_wrap(layer, min_num_params=1e8) return layer def forward( diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py index fad5a64e..5e935df1 100644 --- a/fairseq/optim/cpu_adam.py +++ b/fairseq/optim/cpu_adam.py @@ -107,6 +107,10 @@ class CPUAdam(torch.optim.Optimizer): self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode ) + @property + def supports_flat_params(self): + return True + @torch.no_grad() def step(self, closure=None): loss = None diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index e0b069f1..00ea1bbb 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -322,6 +322,10 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer): def all_reduce_grads(self, module): self.fp32_optimizer.all_reduce_grads(module) + @property + def supports_flat_params(self): + return self.fp32_optimizer.supports_flat_params + class _MemoryEfficientFP16OptimizerMixin(object): def __init__(self, *args, **kwargs): @@ -442,6 +446,10 @@ class _MemoryEfficientFP16OptimizerMixin(object): else: self._multiply_factor = 1.0 + @property + def supports_flat_params(self): + return self.wrapped_optimizer.supports_flat_params + class MemoryEfficientFP16Optimizer( _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer @@ -461,8 +469,10 @@ class MemoryEfficientFP16Optimizer( *supports_memory_efficient_fp16* property. 
""" - def __init__(self, cfg: DictConfig, params, optimizer, **kwargs): - if not optimizer.supports_memory_efficient_fp16: + def __init__( + self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs + ): + if not allow_unsupported and not optimizer.supports_memory_efficient_fp16: raise ValueError( "Unsupported optimizer: {}".format(optimizer.__class__.__name__) ) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 45d9591d..4d47d398 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -63,15 +63,31 @@ class Trainer(object): else: self.device = torch.device("cpu") + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + if self.cfg.common.bf16: + raise ValueError( + "FullyShardedDataParallel is not compatible with --bf16 or " + "--memory-efficient-bf16" + ) + if self.cfg.distributed_training.zero_sharding != "none": + raise ValueError( + "FullyShardedDataParallel is not compatible with --zero-sharding " + "option (it's already built in)" + ) + else: + if self.cfg.distributed_training.cpu_offload: + raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded") + # copy model and criterion to current device/dtype self._criterion = criterion self._model = model - if cfg.common.fp16: - self._criterion = self._criterion.half() - self._model = self._model.half() - elif cfg.common.bf16: - self._criterion = self._criterion.to(dtype=torch.bfloat16) - self._model = self._model.to(dtype=torch.bfloat16) + if cfg.distributed_training.ddp_backend != "fully_sharded": + if cfg.common.fp16: + self._criterion = self._criterion.half() + self._model = self._model.half() + elif cfg.common.bf16: + self._criterion = self._criterion.to(dtype=torch.bfloat16) + self._model = self._model.to(dtype=torch.bfloat16) if ( not cfg.distributed_training.pipeline_model_parallel # the DistributedFairseqModel wrapper will handle moving to device, @@ -171,17 +187,26 @@ class Trainer(object): return ( self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf + ) or ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and self.cfg.distributed_training.cpu_offload ) @property def should_save_checkpoint_on_current_rank(self) -> bool: """Indicates whether to save checkpoints on the current DDP rank.""" - return self.is_data_parallel_master + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + return True + else: + return self.is_data_parallel_master @property def checkpoint_suffix(self) -> str: """Suffix to add to the checkpoint file name.""" - return self.cfg.checkpoint.checkpoint_suffix or "" + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(self.data_parallel_rank) + else: + return self.cfg.checkpoint.checkpoint_suffix or "" @property def criterion(self): @@ -234,7 +259,20 @@ class Trainer(object): ) ) - if self.cfg.common.fp16 or self.cfg.common.bf16: + if ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and self.cfg.common.fp16 + ): + # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper, + # mostly for the grad scaling. But if we don't have the + # --memory-efficient-fp16 flag set, then we're effectively doing + # regular --fp16 and can allow the use of optimizers that would + # otherwise be unsupported by MemoryEfficientFP16Optimizer. 
+ allow_unsupported = not self.cfg.common.memory_efficient_fp16 + self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( + self.cfg, params, allow_unsupported=allow_unsupported + ) + elif self.cfg.common.fp16 or self.cfg.common.bf16: if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: logger.info( "NOTE: your device does NOT support faster training with --fp16, " @@ -254,6 +292,16 @@ class Trainer(object): logger.info("NOTE: your device may support faster training with --fp16") self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + assert not self.cfg.optimization.use_bmuf, \ + "--ddp-backend=fully_sharded is not compatible with BMUF" + assert self._optimizer.supports_flat_params, ( + "--ddp-backend=fully_sharded is only compatible with pointwise " + "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). " + "However, the sharding will result in slightly different results when " + "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)" + ) + if self.cfg.optimization.use_bmuf: self._optimizer = optim.FairseqBMUF( self.cfg.bmuf, @@ -355,6 +403,8 @@ class Trainer(object): # TPUs don't support broadcast yet, so load checkpoints # on every worker for now or self.tpu + # FSDP requires loading checkpoint shards on all ranks + or self.cfg.distributed_training.ddp_backend == "fully_sharded" ) if load_on_all_ranks or self.data_parallel_rank == 0: @@ -965,7 +1015,21 @@ class Trainer(object): metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) def clip_grad_norm(self, clip_norm): - return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None) + + def agg_norm_fn(total_norm): + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + total_norm = total_norm ** 2 + if ( + self.data_parallel_process_group is not None + or torch.distributed.is_initialized() + ): + total_norm = distributed_utils.all_reduce( + total_norm.cuda(), group=self.data_parallel_process_group + ) + total_norm = total_norm ** 0.5 + return total_norm + + return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=agg_norm_fn) def cumulative_training_time(self): if self._cumulative_training_time is None: diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 80ad57ac..d770e4e4 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -18,7 +18,6 @@ import numpy as np import torch from fairseq import ( checkpoint_utils, - distributed_utils, options, quantization_utils, tasks, @@ -27,7 +26,7 @@ from fairseq import ( from fairseq.data import iterators from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf -from fairseq.distributed_utils import is_master +from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils from fairseq.file_io import PathManager from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer @@ -50,7 +49,7 @@ def main(cfg: FairseqConfig) -> None: utils.import_user_module(cfg.common) - if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg: + if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg: # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg)) @@ -87,7 +86,11 @@ def main(cfg: FairseqConfig) -> None: assert 
cfg.criterion, "Please specify criterion to train a model" # Build model and criterion - model = task.build_model(cfg.model) + if cfg.distributed_training.ddp_backend == "fully_sharded": + with fsdp_enable_wrap(cfg.distributed_training): + model = fsdp_wrap(task.build_model(cfg.model)) + else: + model = task.build_model(cfg.model) criterion = task.build_criterion(cfg.criterion) logger.info(model) logger.info("task: {}".format(task.__class__.__name__)) @@ -95,8 +98,8 @@ def main(cfg: FairseqConfig) -> None: logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( "num. model params: {:,} (num. trained: {:,})".format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()), + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad), ) ) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 3cb98897..e10cc767 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -1697,8 +1697,9 @@ class TestActivationCheckpointing(unittest.TestCase): """Neither ----checkpoint-activations nor --offload-activations should change loss""" with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: - create_dummy_data(data_dir, num_examples=20) - preprocess_translation_data(data_dir) + with self.assertLogs(): + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) offload_logs = self._train(data_dir, ["--offload-activations"]) baseline_logs = self._train(data_dir, []) @@ -1720,8 +1721,9 @@ class TestActivationCheckpointing(unittest.TestCase): """--checkpoint-activations should not change loss""" with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: - create_dummy_data(data_dir, num_examples=20) - preprocess_translation_data(data_dir) + with self.assertLogs(): + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) ckpt_logs = self._train(data_dir, ["--checkpoint-activations"]) baseline_logs = self._train(data_dir, []) assert len(baseline_logs) == len(ckpt_logs) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9fb69a5f..a3e39700 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import unittest from typing import Sequence @@ -20,6 +21,12 @@ def sample(id: int, length: int): class TestDataset(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + def test_round_robin_zip_datasets(self): long_dataset = lang_pair_dataset([10, 9, 8, 11]) short_dataset = lang_pair_dataset([11, 9])
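
Note: for reference, the new helpers are meant to be combined the way the fairseq_cli/train.py hunk above does it. A minimal sketch (assuming the usual fairseq `cfg` and `task` objects produced by the standard setup code; this is an illustration, not additional patch content):

    from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap


    def build_model_for_training(cfg, task):
        if cfg.distributed_training.ddp_backend == "fully_sharded":
            # Inside this context, fsdp_wrap() shards modules with fairscale's
            # FullyShardedDataParallel; the transformer encoder/decoder layers
            # additionally opt in via min_num_params in build_*_layer above.
            with fsdp_enable_wrap(cfg.distributed_training):
                return fsdp_wrap(task.build_model(cfg.model))
        # Other backends keep the existing behavior: the model is wrapped later
        # by DistributedFairseqModel in the trainer.
        return task.build_model(cfg.model)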
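
Note: with --ddp-backend=fully_sharded each rank only sees its own shard of the gradients, which is why Trainer.clip_grad_norm above squares the local norm, all-reduces it across the data parallel group, and then takes the square root. A tiny worked example with hypothetical per-rank shard norms:

    import math

    # hypothetical local gradient norms computed on two ranks
    local_norms = [3.0, 4.0]

    # equivalent of: total_norm ** 2 -> all_reduce(sum) -> ** 0.5
    total_norm = math.sqrt(sum(n ** 2 for n in local_norms))
    assert total_norm == 5.0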