From cf17068aad5af49eb9a106e314f617b15438dc9e Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Sat, 4 May 2019 17:08:05 -0700
Subject: [PATCH] Initialize distributed using multiproc with all visible GPUs

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/695

Differential Revision: D15182613

Pulled By: myleott

fbshipit-source-id: 4196346517d8e75ed9e903e9e01ab943d086f6f1
---
 fairseq/distributed_utils.py | 24 ++++++++++++++++++++----
 fairseq/options.py           |  2 ++
 train.py                     | 31 +++++++++++++++++++++++--------
 3 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py
index f29513926..d640ce490 100644
--- a/fairseq/distributed_utils.py
+++ b/fairseq/distributed_utils.py
@@ -8,6 +8,7 @@
 from collections import namedtuple
 import os
 import pickle
+import socket
 import subprocess
 import warnings
 
@@ -42,9 +43,20 @@ def infer_init_method(args):
                 hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                 args.distributed_init_method = 'tcp://{host}:{port}'.format(
                     host=hostnames.split()[0].decode('utf-8'),
-                    port=args.distributed_port)
-                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
-                args.device_id = int(os.environ.get('SLURM_LOCALID'))
+                    port=args.distributed_port,
+                )
+                nnodes = int(os.environ.get('SLURM_NNODES'))
+                ntasks_per_node = int(os.environ.get('SLURM_NTASKS_PER_NODE'))
+                if ntasks_per_node == 1:
+                    assert args.distributed_world_size % nnodes == 0
+                    gpus_per_node = args.distributed_world_size // nnodes
+                    node_id = int(os.environ.get('SLURM_NODEID'))
+                    args.distributed_rank = node_id * gpus_per_node
+                else:
+                    assert ntasks_per_node == args.distributed_world_size // nnodes
+                    args.distributed_no_spawn = True
+                    args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
+                    args.device_id = int(os.environ.get('SLURM_LOCALID'))
             except subprocess.CalledProcessError as e:  # scontrol failed
                 raise e
             except FileNotFoundError:  # Slurm is not installed
@@ -60,13 +72,17 @@ def distributed_init(args):
     else:
         print('| distributed init (rank {}): {}'.format(
             args.distributed_rank, args.distributed_init_method), flush=True)
-
         dist.init_process_group(
             backend=args.distributed_backend,
             init_method=args.distributed_init_method,
             world_size=args.distributed_world_size,
             rank=args.distributed_rank,
         )
+        print('| initialized host {} as rank {}'.format(
+            socket.gethostname(), args.distributed_rank), flush=True)
+
+        # perform a dummy all-reduce to initialize the NCCL communicator
+        dist.all_reduce(torch.rand(1).cuda())
 
         suppress_output(is_master(args))
 
diff --git a/fairseq/options.py b/fairseq/options.py
index b669fe32f..86382cea2 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -266,6 +266,8 @@ def add_distributed_training_args(parser):
                        help='port number (not required if using --distributed-init-method)')
     group.add_argument('--device-id', '--local_rank', default=0, type=int,
                        help='which GPU to use (usually configured automatically)')
+    group.add_argument('--distributed-no-spawn', action='store_true',
+                       help='do not spawn multiple processes even if multiple GPUs are visible')
     group.add_argument('--ddp-backend', default='c10d', type=str,
                        choices=['c10d', 'no_c10d'],
                        help='DistributedDataParallel backend')
diff --git a/train.py b/train.py
index 96ae251b0..37466b0d7 100644
--- a/train.py
+++ b/train.py
@@ -23,16 +23,21 @@ from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
 
 
-def main(args):
+def main(args, init_distributed=False):
     utils.import_user_module(args)
 
-    if args.max_tokens is None:
-        args.max_tokens = 6000
-    print(args)
+    assert args.max_tokens is not None or args.max_sentences is not None, \
+        'Must specify batch size either with --max-tokens or --max-sentences'
 
+    # Initialize CUDA and distributed training
     if torch.cuda.is_available() and not args.cpu:
         torch.cuda.set_device(args.device_id)
     torch.manual_seed(args.seed)
+    if init_distributed:
+        args.distributed_rank = distributed_utils.distributed_init(args)
+
+    # Print args
+    print(args)
 
     # Setup task, e.g., translation, language modeling, etc.
     task = tasks.setup_task(args)
@@ -372,11 +377,11 @@ def load_dataset_splits(args, task):
                 raise e
 
 
-def distributed_main(i, args):
+def distributed_main(i, args, start_rank=0):
     args.device_id = i
     if args.distributed_rank is None:  # torch.multiprocessing.spawn
-        args.distributed_rank = i
-    main(args)
+        args.distributed_rank = start_rank + i
+    main(args, init_distributed=True)
 
 
 def cli_main():
@@ -388,9 +393,19 @@ def cli_main():
 
     if args.distributed_init_method is not None:
         # distributed training
-        distributed_main(args.device_id, args)
+        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
+            start_rank = args.distributed_rank
+            args.distributed_rank = None  # assign automatically
+            torch.multiprocessing.spawn(
+                fn=distributed_main,
+                args=(args, start_rank),
+                nprocs=torch.cuda.device_count(),
+            )
+        else:
+            distributed_main(args.device_id, args)
     elif args.distributed_world_size > 1:
         # fallback for single node with multiple GPUs
+        assert args.distributed_world_size <= torch.cuda.device_count()
         port = random.randint(10000, 20000)
         args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
         args.distributed_rank = None  # set based on device id
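
Illustrative sketch (not part of the patch): the single-node startup path introduced in cli_main() can be reduced to the standalone Python below, which spawns one process per visible GPU, derives each worker's global rank from a starting rank plus its spawn index, and performs the dummy all-reduce to initialize the NCCL communicator. The worker name run_worker, the port, and the tcp://localhost init method are assumptions made only for this example; they are not fairseq APIs.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_worker(local_rank, start_rank, world_size, init_method):
    # Hypothetical worker mirroring distributed_main(): mp.spawn passes the
    # process index as the first argument; the global rank is the node's
    # starting rank plus that index.
    rank = start_rank + local_rank
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend='nccl',
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )
    # dummy all-reduce to initialize the NCCL communicator, as in the patch
    dist.all_reduce(torch.rand(1).cuda())
    print('| initialized rank {} on GPU {}'.format(rank, local_rank), flush=True)


if __name__ == '__main__':
    # Single-node case: start_rank is 0 and the world size is the number of
    # visible GPUs. On SLURM with one task per node, start_rank would instead
    # be SLURM_NODEID * gpus_per_node, as computed in infer_init_method() above.
    world_size = torch.cuda.device_count()
    mp.spawn(
        fn=run_worker,
        args=(0, world_size, 'tcp://localhost:12345'),
        nprocs=world_size,
    )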