Update implemention of SlowMo to its implementation in Fairscale (#3996)

Summary:
- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
SlowMo is being moved to [Fairscale](https://fairscale.readthedocs.io/en/latest/). This commit updates the implementation of SlowMo to the Fairscale version. It also adds tests for SlowMo.
Note: This PR is currently for review. It will be merged at a later date once SlowMo has been updated to Fairscale. SlowMo is being merged to Fairscale as part of [a PR](https://github.com/facebookresearch/fairscale/pull/378). So, once that PR is merged to Fairscale, this PR on Fairseq will be ready for merge

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding �

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3996

Reviewed By: dianaml0

Differential Revision: D32280163

Pulled By: vtantia

fbshipit-source-id: 70c97b04a7cdc90ada7099375c2a31b0c978ba70
This commit is contained in:
Vinayak Tantia 2021-11-09 09:41:21 -08:00 committed by Facebook GitHub Bot
parent 0b21875e45
commit 3a5838c320
5 changed files with 45 additions and 27 deletions

View File

@ -338,8 +338,14 @@ class DistributedTrainingConfig(FairseqDataclass):
"0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs"
},
)
slowmo_algorithm: str = field(
default="LocalSGD", metadata={"help": "whether to use LocalSGD or SGP"}
slowmo_base_algorithm: str = field(
default="localsgd",
metadata={
"help": "Base algorithm. Either 'localsgd' or 'sgp'. Please refer "
"to the documentation of 'slowmo_base_algorithm' parameter in "
"https://fairscale.readthedocs.io/en/latest/api/experimental/nn/slowmo_ddp.html "
"for more details"
},
)
localsgd_frequency: int = field(
default=3, metadata={"help": "Local SGD allreduce frequency"}

View File

@ -41,7 +41,7 @@ DDP_BACKEND_CHOICES = ChoiceEnum([
"legacy_ddp",
"no_c10d", # alias for legacy_ddp
"pytorch_ddp",
"slow_mo",
"slowmo",
])
DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"])
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"])

View File

@ -23,11 +23,11 @@ from fairseq.distributed import (
logger = logging.getLogger(__name__)
_GOSSIP_DISABLED = False
_SLOWMO_DDP_DISABLED = False
try:
import gossip
from fairscale.experimental.nn.data_parallel import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel
except ImportError:
_GOSSIP_DISABLED = True
_SLOWMO_DDP_DISABLED = True
def DistributedFairseqModel(args, model, process_group, device):
@ -89,11 +89,11 @@ def DistributedFairseqModel(args, model, process_group, device):
)
# forward missing getattr and state_dict/load_state_dict to orig model
wrapped_model = ModuleProxyWrapper(wrapped_model)
elif args.ddp_backend == "slow_mo":
if _GOSSIP_DISABLED:
elif args.ddp_backend == "slowmo":
if _SLOWMO_DDP_DISABLED:
raise ImportError(
"Cannot find gossip library. Please install from: "
"github.com/facebookresearch/stochastic_gradient_push"
"Cannot find SlowMoDistributedDataParallel. "
"Please install fairscale with: pip install fairscale"
)
# The values of slowmo_momentum below were obtained by tuning on the
@ -107,15 +107,14 @@ def DistributedFairseqModel(args, model, process_group, device):
args.slowmo_momentum = 0.5
else:
args.slowmo_momentum = 0.6
slowmo_base_algorithm = SlowMoBaseAlgorithm[args.slowmo_base_algorithm.upper()]
wrapped_model = gossip.GossipDataParallel(
wrapped_model = SlowMoDistributedDataParallel(
module=model.to(device),
device_ids=[args.device_id],
output_device=args.device_id,
broadcast_buffers=args.broadcast_buffers,
nprocs_per_node=args.nprocs_per_node,
slowmo_momentum=args.slowmo_momentum,
localsgd=(args.slowmo_algorithm == "LocalSGD"),
slowmo_base_algorithm=slowmo_base_algorithm,
localsgd_frequency=args.localsgd_frequency,
)
# forward missing getattr and state_dict/load_state_dict to orig model

View File

@ -855,7 +855,7 @@ class Trainer(object):
if not self.tpu:
if (
not self.cfg.optimization.use_bmuf
and self.cfg.distributed_training.ddp_backend != "slow_mo"
and self.cfg.distributed_training.ddp_backend != "slowmo"
):
self._check_grad_norms(grad_norm)
if not torch.isfinite(grad_norm).all():
@ -912,18 +912,14 @@ class Trainer(object):
# Some distributed wrappers (e.g., SlowMo) need access to the optimizer
# after the step
if hasattr(self.model, "perform_additional_optimizer_actions"):
if hasattr(self.optimizer, "fp32_params"):
self.model.perform_additional_optimizer_actions(
self.optimizer.optimizer, self.optimizer.fp32_params
)
else:
self.model.perform_additional_optimizer_actions(
self.optimizer.optimizer
)
if hasattr(self.model, "perform_slowmo"):
self.model.perform_slowmo(
self.optimizer.optimizer,
getattr(self.optimizer, "fp32_params", None)
)
logging_output = None
if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo":
if not overflow or self.cfg.distributed_training.ddp_backend == "slowmo":
self.set_num_updates(self.get_num_updates() + 1)
if self.cfg.ema.store_ema:

View File

@ -32,15 +32,32 @@ class TestTranslationGPU(unittest.TestCase):
logging.disable(logging.NOTSET)
def test_fp16_multigpu(self):
self._test_multigpu(
"test_fp16", ["--fp16"]
)
def test_slowmo_multigpu(self):
self._test_multigpu(
"test_slowmo",
["--ddp-backend", "slowmo", "--nprocs-per-node", "1"]
)
def test_slowmo_single_node_multigpu(self):
self._test_multigpu(
"test_slowmo_single_node",
["--ddp-backend", "slowmo", "--nprocs-per-node", "2"]
)
def _test_multigpu(self, test_name, test_args):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory("test_fp16") as data_dir:
with tempfile.TemporaryDirectory(test_name) as data_dir:
log = os.path.join(data_dir, "train.log")
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(
data_dir,
"fconv_iwslt_de_en",
["--fp16", "--log-file", log],
test_args + ["--log-file", log],
world_size=min(torch.cuda.device_count(), 2),
)
generate_main(data_dir)