diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py
index 9730f5caa..f29513926 100644
--- a/fairseq/distributed_utils.py
+++ b/fairseq/distributed_utils.py
@@ -9,6 +9,7 @@ from collections import namedtuple
 import os
 import pickle
 import subprocess
+import warnings
 
 import torch
 import torch.distributed as dist
@@ -54,18 +55,22 @@ def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')
 
-    print('| distributed init (rank {}): {}'.format(
-        args.distributed_rank, args.distributed_init_method), flush=True)
+    if torch.distributed.is_initialized():
+        warnings.warn('Distributed is already initialized, cannot initialize twice!')
+    else:
+        print('| distributed init (rank {}): {}'.format(
+            args.distributed_rank, args.distributed_init_method), flush=True)
 
-    dist.init_process_group(
-        backend=args.distributed_backend,
-        init_method=args.distributed_init_method,
-        world_size=args.distributed_world_size,
-        rank=args.distributed_rank,
-    )
+        dist.init_process_group(
+            backend=args.distributed_backend,
+            init_method=args.distributed_init_method,
+            world_size=args.distributed_world_size,
+            rank=args.distributed_rank,
+        )
 
-    suppress_output(is_master(args))
+        suppress_output(is_master(args))
 
+    args.distributed_rank = torch.distributed.get_rank()
     return args.distributed_rank
diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py
index 5498645ed..038a8a5d9 100644
--- a/fairseq/models/distributed_fairseq_model.py
+++ b/fairseq/models/distributed_fairseq_model.py
@@ -6,8 +6,11 @@
 # can be found in the PATENTS file in the same directory.
 
 import inspect
+import socket
+
 from torch.nn import parallel
 
+from fairseq import distributed_utils
 from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
 
 from . import BaseFairseqModel
@@ -26,6 +29,9 @@ def DistributedFairseqModel(args, model):
         args (argparse.Namespace): fairseq args
         model (BaseFairseqModel): model to wrap
     """
+    # rendezvous with other workers
+    args.distributed_rank = distributed_utils.distributed_init(args)
+    print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
 
     # determine which DDP class to extend
     assert isinstance(model, BaseFairseqModel)
diff --git a/train.py b/train.py
index c745b0024..c963d7afb 100644
--- a/train.py
+++ b/train.py
@@ -23,7 +23,7 @@ from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
 
 
-def main(args, init_distributed=False):
+def main(args):
     utils.import_user_module(args)
 
     if args.max_tokens is None:
@@ -82,12 +82,6 @@ def main(args, init_distributed=False):
         num_workers=args.num_workers,
     )
 
-    # Initialize distributed training (after data loading)
-    if init_distributed:
-        import socket
-        args.distributed_rank = distributed_utils.distributed_init(args)
-        print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
-
     # Load the latest checkpoint if one is available
     if not load_checkpoint(args, trainer, epoch_itr):
         trainer.dummy_train_step([dummy_batch])
@@ -390,7 +384,7 @@ def distributed_main(i, args):
     args.device_id = i
     if args.distributed_rank is None:  # torch.multiprocessing.spawn
         args.distributed_rank = i
-    main(args, init_distributed=True)
+    main(args)
 
 
 def cli_main():
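
The key behavior change above is that distributed_init() becomes idempotent: it guards init_process_group() with torch.distributed.is_initialized() and always returns the rank reported by the process group, so the rendezvous can move from train.py into DistributedFairseqModel without risking a double initialization. Below is a minimal standalone sketch of that guard pattern, not code from the patch; the function name init_once and the gloo/tcp defaults are illustrative assumptions only.

# Sketch of the idempotent-init pattern introduced by the patch.
import warnings

import torch.distributed as dist


def init_once(backend='gloo', init_method='tcp://127.0.0.1:29500',
              world_size=1, rank=0):
    """Initialize the default process group only if it is not already up."""
    if dist.is_initialized():
        # Repeated calls degrade to a warning instead of crashing.
        warnings.warn('Distributed is already initialized, cannot initialize twice!')
    else:
        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )
    # Always report the rank the process group actually assigned.
    return dist.get_rank()


if __name__ == '__main__':
    print(init_once())  # initializes and prints 0
    print(init_once())  # warns, then prints 0 again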