From 3a5838c320c5b7afc3a6fba5736bca22503ef804 Mon Sep 17 00:00:00 2001
From: Vinayak Tantia
Date: Tue, 9 Nov 2021 09:41:21 -0800
Subject: [PATCH] Update implementation of SlowMo to its implementation in
 Fairscale (#3996)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
- [ ] Was this discussed/approved via a Github issue? (not needed for typos or doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
SlowMo is being moved to [Fairscale](https://fairscale.readthedocs.io/en/latest/). This commit updates the implementation of SlowMo to the Fairscale version. It also adds tests for SlowMo.

Note: This PR is currently up for review; it will be merged once SlowMo has landed in Fairscale. SlowMo is being added to Fairscale in [a PR](https://github.com/facebookresearch/fairscale/pull/378), so once that PR is merged, this Fairseq PR will be ready to merge.

## PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in a Github issue, there is a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding!

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3996

Reviewed By: dianaml0

Differential Revision: D32280163

Pulled By: vtantia

fbshipit-source-id: 70c97b04a7cdc90ada7099375c2a31b0c978ba70
---
 fairseq/dataclass/configs.py                | 10 ++++++++--
 fairseq/dataclass/constants.py              |  2 +-
 fairseq/models/distributed_fairseq_model.py | 21 ++++++++++-----------
 fairseq/trainer.py                          | 18 +++++++-----------
 tests/gpu/test_binaries_gpu.py              | 21 +++++++++++++++++--
 5 files changed, 45 insertions(+), 27 deletions(-)

diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py
index 80caa0f2d..289bb8896 100644
--- a/fairseq/dataclass/configs.py
+++ b/fairseq/dataclass/configs.py
@@ -338,8 +338,14 @@ class DistributedTrainingConfig(FairseqDataclass):
             "0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs"
         },
     )
-    slowmo_algorithm: str = field(
-        default="LocalSGD", metadata={"help": "whether to use LocalSGD or SGP"}
+    slowmo_base_algorithm: str = field(
+        default="localsgd",
+        metadata={
+            "help": "Base algorithm. Either 'localsgd' or 'sgp'. Please refer "
+            "to the documentation of 'slowmo_base_algorithm' parameter in "
+            "https://fairscale.readthedocs.io/en/latest/api/experimental/nn/slowmo_ddp.html "
+            "for more details"
+        },
     )
     localsgd_frequency: int = field(
         default=3, metadata={"help": "Local SGD allreduce frequency"}
diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py
index 4f159cfe9..7e5aef706 100644
--- a/fairseq/dataclass/constants.py
+++ b/fairseq/dataclass/constants.py
@@ -41,7 +41,7 @@ DDP_BACKEND_CHOICES = ChoiceEnum([
     "legacy_ddp",
     "no_c10d",  # alias for legacy_ddp
     "pytorch_ddp",
-    "slow_mo",
+    "slowmo",
 ])
 DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"])
 DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"])
diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py
index 5eda22764..de8d6ac11 100644
--- a/fairseq/models/distributed_fairseq_model.py
+++ b/fairseq/models/distributed_fairseq_model.py
@@ -23,11 +23,11 @@ from fairseq.distributed import (
 logger = logging.getLogger(__name__)
 
 
-_GOSSIP_DISABLED = False
+_SLOWMO_DDP_DISABLED = False
 try:
-    import gossip
+    from fairscale.experimental.nn.data_parallel import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel
 except ImportError:
-    _GOSSIP_DISABLED = True
+    _SLOWMO_DDP_DISABLED = True
 
 
 def DistributedFairseqModel(args, model, process_group, device):
@@ -89,11 +89,11 @@ def DistributedFairseqModel(args, model, process_group, device):
         )
         # forward missing getattr and state_dict/load_state_dict to orig model
         wrapped_model = ModuleProxyWrapper(wrapped_model)
-    elif args.ddp_backend == "slow_mo":
-        if _GOSSIP_DISABLED:
+    elif args.ddp_backend == "slowmo":
+        if _SLOWMO_DDP_DISABLED:
             raise ImportError(
-                "Cannot find gossip library. Please install from: "
-                "github.com/facebookresearch/stochastic_gradient_push"
+                "Cannot find SlowMoDistributedDataParallel. "
" + "Please install fairscale with: pip install fairscale" ) # The values of slowmo_momentum below were obtained by tuning on the @@ -107,15 +107,14 @@ def DistributedFairseqModel(args, model, process_group, device): args.slowmo_momentum = 0.5 else: args.slowmo_momentum = 0.6 + slowmo_base_algorithm = SlowMoBaseAlgorithm[args.slowmo_base_algorithm.upper()] - wrapped_model = gossip.GossipDataParallel( + wrapped_model = SlowMoDistributedDataParallel( module=model.to(device), - device_ids=[args.device_id], - output_device=args.device_id, broadcast_buffers=args.broadcast_buffers, nprocs_per_node=args.nprocs_per_node, slowmo_momentum=args.slowmo_momentum, - localsgd=(args.slowmo_algorithm == "LocalSGD"), + slowmo_base_algorithm=slowmo_base_algorithm, localsgd_frequency=args.localsgd_frequency, ) # forward missing getattr and state_dict/load_state_dict to orig model diff --git a/fairseq/trainer.py b/fairseq/trainer.py index e46ccfe0b..94130c8c3 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -855,7 +855,7 @@ class Trainer(object): if not self.tpu: if ( not self.cfg.optimization.use_bmuf - and self.cfg.distributed_training.ddp_backend != "slow_mo" + and self.cfg.distributed_training.ddp_backend != "slowmo" ): self._check_grad_norms(grad_norm) if not torch.isfinite(grad_norm).all(): @@ -912,18 +912,14 @@ class Trainer(object): # Some distributed wrappers (e.g., SlowMo) need access to the optimizer # after the step - if hasattr(self.model, "perform_additional_optimizer_actions"): - if hasattr(self.optimizer, "fp32_params"): - self.model.perform_additional_optimizer_actions( - self.optimizer.optimizer, self.optimizer.fp32_params - ) - else: - self.model.perform_additional_optimizer_actions( - self.optimizer.optimizer - ) + if hasattr(self.model, "perform_slowmo"): + self.model.perform_slowmo( + self.optimizer.optimizer, + getattr(self.optimizer, "fp32_params", None) + ) logging_output = None - if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo": + if not overflow or self.cfg.distributed_training.ddp_backend == "slowmo": self.set_num_updates(self.get_num_updates() + 1) if self.cfg.ema.store_ema: diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index de8c24261..99eb7f558 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -32,15 +32,32 @@ class TestTranslationGPU(unittest.TestCase): logging.disable(logging.NOTSET) def test_fp16_multigpu(self): + self._test_multigpu( + "test_fp16", ["--fp16"] + ) + + def test_slowmo_multigpu(self): + self._test_multigpu( + "test_slowmo", + ["--ddp-backend", "slowmo", "--nprocs-per-node", "1"] + ) + + def test_slowmo_single_node_multigpu(self): + self._test_multigpu( + "test_slowmo_single_node", + ["--ddp-backend", "slowmo", "--nprocs-per-node", "2"] + ) + + def _test_multigpu(self, test_name, test_args): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory("test_fp16") as data_dir: + with tempfile.TemporaryDirectory(test_name) as data_dir: log = os.path.join(data_dir, "train.log") create_dummy_data(data_dir) preprocess_translation_data(data_dir) train_translation_model( data_dir, "fconv_iwslt_de_en", - ["--fp16", "--log-file", log], + test_args + ["--log-file", log], world_size=min(torch.cuda.device_count(), 2), ) generate_main(data_dir)