Initialize distributed using multiproc with all visible GPUs

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/695

Differential Revision: D15182613

Pulled By: myleott

fbshipit-source-id: 4196346517d8e75ed9e903e9e01ab943d086f6f1
Authored by Myle Ott on 2019-05-04 17:08:05 -07:00; committed by Facebook Github Bot
parent 96ac28d33d
commit cf17068aad
3 changed files with 45 additions and 12 deletions

fairseq/distributed_utils.py

@@ -8,6 +8,7 @@
 from collections import namedtuple
 import os
 import pickle
+import socket
 import subprocess
 import warnings
@@ -42,9 +43,20 @@ def infer_init_method(args):
                 hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                 args.distributed_init_method = 'tcp://{host}:{port}'.format(
                     host=hostnames.split()[0].decode('utf-8'),
-                    port=args.distributed_port)
-                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
-                args.device_id = int(os.environ.get('SLURM_LOCALID'))
+                    port=args.distributed_port,
+                )
+                nnodes = int(os.environ.get('SLURM_NNODES'))
+                ntasks_per_node = int(os.environ.get('SLURM_NTASKS_PER_NODE'))
+                if ntasks_per_node == 1:
+                    assert args.distributed_world_size % nnodes == 0
+                    gpus_per_node = args.distributed_world_size // nnodes
+                    node_id = int(os.environ.get('SLURM_NODEID'))
+                    args.distributed_rank = node_id * gpus_per_node
+                else:
+                    assert ntasks_per_node == args.distributed_world_size // nnodes
+                    args.distributed_no_spawn = True
+                    args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
+                    args.device_id = int(os.environ.get('SLURM_LOCALID'))
             except subprocess.CalledProcessError as e:  # scontrol failed
                 raise e
             except FileNotFoundError:  # Slurm is not installed
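
The new Slurm branch distinguishes two launch styles. With one Slurm task per node, fairseq spawns one process per GPU itself, so only the node's base rank is computed here and the per-GPU offset is added later in distributed_main. With one Slurm task per GPU, Slurm has already assigned a global rank and local GPU index, so spawning is disabled via distributed_no_spawn. A minimal sketch of that rank arithmetic (the helper name is hypothetical; it reads the same environment variables as the diff above):

import os

def infer_slurm_ranks(world_size):
    # Hypothetical helper mirroring the branch above: returns the base
    # distributed rank for this Slurm task plus a flag saying whether
    # local per-GPU processes still need to be spawned.
    nnodes = int(os.environ['SLURM_NNODES'])
    ntasks_per_node = int(os.environ['SLURM_NTASKS_PER_NODE'])
    node_id = int(os.environ['SLURM_NODEID'])
    if ntasks_per_node == 1:
        # One task per node: child i spawned on this node becomes
        # rank node_id * gpus_per_node + i.
        assert world_size % nnodes == 0
        gpus_per_node = world_size // nnodes
        return node_id * gpus_per_node, True   # (base_rank, should_spawn)
    else:
        # One task per GPU: SLURM_PROCID is already the global rank and
        # SLURM_LOCALID the GPU to bind to; nothing left to spawn.
        assert ntasks_per_node == world_size // nnodes
        return int(os.environ['SLURM_PROCID']), False

For example, on node 1 of a 2-node, 8-GPU-per-node job launched with --ntasks-per-node=1, infer_slurm_ranks(16) returns (8, True) and the spawned children take ranks 8 through 15.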
@@ -60,13 +72,17 @@ def distributed_init(args):
     else:
         print('| distributed init (rank {}): {}'.format(
             args.distributed_rank, args.distributed_init_method), flush=True)
         dist.init_process_group(
             backend=args.distributed_backend,
             init_method=args.distributed_init_method,
             world_size=args.distributed_world_size,
             rank=args.distributed_rank,
         )
+        print('| initialized host {} as rank {}'.format(
+            socket.gethostname(), args.distributed_rank), flush=True)
+        # perform a dummy all-reduce to initialize the NCCL communicator
+        dist.all_reduce(torch.rand(1).cuda())
 
     suppress_output(is_master(args))
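
The dummy all-reduce is worth calling out: NCCL builds its communicator lazily on the first collective, so forcing one right after init_process_group makes a bad rendezvous address or GPU mapping fail at startup rather than at the first gradient synchronization. A condensed sketch of the sequence, assuming a tcp:// init method like the one built in infer_init_method and a CUDA device already selected (the helper name is ours, not fairseq's):

import socket

import torch
import torch.distributed as dist

def init_and_warm_up(init_method, world_size, rank, backend='nccl'):
    # Join the process group, log which host got which rank, then run a
    # throwaway all-reduce so the NCCL communicator is created immediately.
    dist.init_process_group(
        backend=backend,
        init_method=init_method,  # e.g. 'tcp://<head-node>:<port>'
        world_size=world_size,
        rank=rank,
    )
    print('| initialized host {} as rank {}'.format(
        socket.gethostname(), rank), flush=True)
    dist.all_reduce(torch.rand(1).cuda())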

fairseq/options.py

@@ -266,6 +266,8 @@ def add_distributed_training_args(parser):
                        help='port number (not required if using --distributed-init-method)')
     group.add_argument('--device-id', '--local_rank', default=0, type=int,
                        help='which GPU to use (usually configured automatically)')
+    group.add_argument('--distributed-no-spawn', action='store_true',
+                       help='do not spawn multiple processes even if multiple GPUs are visible')
     group.add_argument('--ddp-backend', default='c10d', type=str,
                        choices=['c10d', 'no_c10d'],
                        help='DistributedDataParallel backend')

train.py

@@ -23,16 +23,21 @@ from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
 
 
-def main(args):
+def main(args, init_distributed=False):
     utils.import_user_module(args)
 
-    if args.max_tokens is None:
-        args.max_tokens = 6000
-    print(args)
+    assert args.max_tokens is not None or args.max_sentences is not None, \
+        'Must specify batch size either with --max-tokens or --max-sentences'
 
+    # Initialize CUDA and distributed training
     if torch.cuda.is_available() and not args.cpu:
         torch.cuda.set_device(args.device_id)
     torch.manual_seed(args.seed)
+    if init_distributed:
+        args.distributed_rank = distributed_utils.distributed_init(args)
+
+    # Print args
+    print(args)
 
     # Setup task, e.g., translation, language modeling, etc.
     task = tasks.setup_task(args)
@@ -372,11 +377,11 @@ def load_dataset_splits(args, task):
             raise e
 
 
-def distributed_main(i, args):
+def distributed_main(i, args, start_rank=0):
     args.device_id = i
     if args.distributed_rank is None:  # torch.multiprocessing.spawn
-        args.distributed_rank = i
-    main(args)
+        args.distributed_rank = start_rank + i
+    main(args, init_distributed=True)
 
 
 def cli_main():
@@ -388,9 +393,19 @@ def cli_main():
 
     if args.distributed_init_method is not None:
         # distributed training
-        distributed_main(args.device_id, args)
+        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
+            start_rank = args.distributed_rank
+            args.distributed_rank = None  # assign automatically
+            torch.multiprocessing.spawn(
+                fn=distributed_main,
+                args=(args, start_rank),
+                nprocs=torch.cuda.device_count(),
+            )
+        else:
+            distributed_main(args.device_id, args)
     elif args.distributed_world_size > 1:
         # fallback for single node with multiple GPUs
+        assert args.distributed_world_size <= torch.cuda.device_count()
         port = random.randint(10000, 20000)
         args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
         args.distributed_rank = None  # set based on device id
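
Taken together, cli_main now spawns one worker per visible GPU with torch.multiprocessing.spawn and offsets each worker's rank by the node's base rank, which is what lets a multi-node job run with a single Slurm task per node. A self-contained single-node sketch of that pattern, assuming at least one visible GPU (the localhost rendezvous and the worker/port names are ours, not fairseq's API):

import random
import socket

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(i, init_method, world_size, start_rank):
    # Mirrors distributed_main: local index i maps to global rank
    # start_rank + i, and each process pins itself to GPU i.
    rank = start_rank + i
    torch.cuda.set_device(i)
    dist.init_process_group(
        backend='nccl',
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )
    # dummy all-reduce, as in distributed_init above
    t = torch.full((1,), float(rank), device='cuda')
    dist.all_reduce(t)
    print('| rank {} on {} sees rank-sum {}'.format(
        rank, socket.gethostname(), int(t.item())), flush=True)

if __name__ == '__main__':
    ngpus = torch.cuda.device_count()
    port = random.randint(10000, 20000)
    init_method = 'tcp://localhost:{port}'.format(port=port)
    start_rank = 0  # on node k of a multi-node job this would be k * ngpus
    mp.spawn(
        fn=worker,
        args=(init_method, ngpus, start_rank),
        nprocs=ngpus,
    )

On a machine with, say, 4 GPUs this prints four lines reporting ranks 0 through 3 and a rank-sum of 6, confirming that every spawned process joined the same communicator.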