Move distributed_init into DistributedFairseqModel (#687)

Summary:
This should make rendezvous happen as lazily as possible.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/687

Differential Revision: D15151145

Pulled By: myleott

fbshipit-source-id: d70816a85414c5d509a6b12e2b339b4736db2c88
Myle Ott 2019-05-02 08:02:46 -07:00 committed by Facebook Github Bot
parent fb18be00f7
commit 34726d5612
3 changed files with 22 additions and 17 deletions
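
The summary above promises that rendezvous happens "as lazily as possible": after this change a training process only touches torch.distributed when the model is wrapped in DistributedFairseqModel, not at the top of main(). A rough sketch of the resulting call pattern (illustration only, not fairseq code; run_worker, args, and the pre-built model are hypothetical stand-ins):

from fairseq.models.distributed_fairseq_model import DistributedFairseqModel


def run_worker(args, model):
    # Argument handling, data loading, and model construction can all run
    # (and fail fast) here before any process group exists.
    wrapped = DistributedFairseqModel(args, model)  # rendezvous happens inside this call
    return wrapped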


@@ -9,6 +9,7 @@ from collections import namedtuple
 import os
 import pickle
 import subprocess
+import warnings
 
 import torch
 import torch.distributed as dist
@@ -54,18 +55,22 @@ def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')
 
-    print('| distributed init (rank {}): {}'.format(
-        args.distributed_rank, args.distributed_init_method), flush=True)
+    if torch.distributed.is_initialized():
+        warnings.warn('Distributed is already initialized, cannot initialize twice!')
+    else:
+        print('| distributed init (rank {}): {}'.format(
+            args.distributed_rank, args.distributed_init_method), flush=True)
 
-    dist.init_process_group(
-        backend=args.distributed_backend,
-        init_method=args.distributed_init_method,
-        world_size=args.distributed_world_size,
-        rank=args.distributed_rank,
-    )
+        dist.init_process_group(
+            backend=args.distributed_backend,
+            init_method=args.distributed_init_method,
+            world_size=args.distributed_world_size,
+            rank=args.distributed_rank,
+        )
 
-    suppress_output(is_master(args))
+        suppress_output(is_master(args))
 
     args.distributed_rank = torch.distributed.get_rank()
     return args.distributed_rank
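
The guard introduced above makes distributed_init safe to call more than once: a second call warns instead of failing inside init_process_group, and the current rank is still returned. A minimal, self-contained sketch of the same pattern outside fairseq (the init_once name and the single-process gloo settings are assumptions for illustration):

import os
import warnings

import torch.distributed as dist


def init_once(backend='gloo', world_size=1, rank=0):
    """Initialize the default process group at most once; warn on repeat calls."""
    if dist.is_initialized():
        warnings.warn('Distributed is already initialized, cannot initialize twice!')
    else:
        # env:// rendezvous for a single local process (illustration only)
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '29500')
        dist.init_process_group(backend=backend, world_size=world_size, rank=rank)
    return dist.get_rank()


if __name__ == '__main__':
    print(init_once())  # performs the rendezvous, prints 0
    print(init_once())  # warns and returns the existing rank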


@@ -6,8 +6,11 @@
 # can be found in the PATENTS file in the same directory.
 
 import inspect
+import socket
 
 from torch.nn import parallel
 
+from fairseq import distributed_utils
 from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
 
 from . import BaseFairseqModel
@@ -26,6 +29,9 @@ def DistributedFairseqModel(args, model):
         args (argparse.Namespace): fairseq args
         model (BaseFairseqModel): model to wrap
     """
+    # rendezvous with other workers
+    args.distributed_rank = distributed_utils.distributed_init(args)
+    print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
 
     # determine which DDP class to extend
     assert isinstance(model, BaseFairseqModel)


@@ -23,7 +23,7 @@ from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
 
 
-def main(args, init_distributed=False):
+def main(args):
     utils.import_user_module(args)
 
     if args.max_tokens is None:
@@ -82,12 +82,6 @@ def main(args, init_distributed=False):
         num_workers=args.num_workers,
     )
 
-    # Initialize distributed training (after data loading)
-    if init_distributed:
-        import socket
-        args.distributed_rank = distributed_utils.distributed_init(args)
-        print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
-
     # Load the latest checkpoint if one is available
     if not load_checkpoint(args, trainer, epoch_itr):
         trainer.dummy_train_step([dummy_batch])
@@ -390,7 +384,7 @@ def distributed_main(i, args):
     args.device_id = i
     if args.distributed_rank is None:  # torch.multiprocessing.spawn
         args.distributed_rank = i
-    main(args, init_distributed=True)
+    main(args)
 
 
 def cli_main():
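
With init_distributed gone from main(), each worker spawned by distributed_main just records its rank and runs main(); the process-group setup is deferred until the trainer wraps the model. A toy sketch of that spawn shape (not fairseq code; worker and the 2-process setting are assumptions):

import torch.multiprocessing as mp


def worker(rank, world_size):
    # Per-worker setup (args, data, model construction) runs here first;
    # in fairseq the rendezvous is now deferred until the model is wrapped.
    print('worker {}/{} ready before any rendezvous'.format(rank, world_size))


if __name__ == '__main__':
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)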