Make multiprocessing_train.py work with multi-node setups

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/425

Differential Revision: D13558340

Pulled By: myleott

fbshipit-source-id: dff8c77027e821d8c80bfbd6a6ccce9ca1a44b78
Myle Ott authored on 2018-12-28 07:54:49 -08:00; committed by Facebook Github Bot
parent 58dd1862f6
commit 0cb87130e7
2 changed files with 11 additions and 9 deletions

multiprocessing_train.py

@@ -7,7 +7,6 @@
 # can be found in the PATENTS file in the same directory.
 import os
-import random
 import signal
 import torch
@@ -17,12 +16,6 @@ from train import main as single_process_main
 def main(args):
-    # Set distributed training parameters for a single node.
-    args.distributed_world_size = torch.cuda.device_count()
-    port = random.randint(10000, 20000)
-    args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
-    args.distributed_init_host = 'localhost'
-    args.distributed_port = port + 1
     if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
         print('| WARNING: when using --update-freq on a single machine, you '
               'will get better performance with --ddp-backend=no_c10d')
@@ -35,8 +28,9 @@ def main(args):
     # Train with multiprocessing.
     procs = []
-    for i in range(args.distributed_world_size):
-        args.distributed_rank = i
+    base_rank = args.distributed_rank
+    for i in range(torch.cuda.device_count()):
+        args.distributed_rank = base_rank + i
         args.device_id = i
         procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True))
         procs[i].start()
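For illustration only, a minimal, self-contained sketch of the rank arithmetic the spawn loop above now uses: each node passes its own base rank in via args.distributed_rank, and its local GPUs receive consecutive global ranks on top of that base. The helper name below is hypothetical, not part of fairseq.

# Hypothetical helper mirroring the new spawn loop: one worker per local GPU,
# with globally unique ranks derived from the node's base rank.
def local_ranks(base_rank, gpus_per_node):
    return [(device_id, base_rank + device_id) for device_id in range(gpus_per_node)]

# Node 1 of an assumed 2-node x 4-GPU job (world size 8):
print(local_ranks(base_rank=4, gpus_per_node=4))
# -> [(0, 4), (1, 5), (2, 6), (3, 7)]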

train.py

@@ -13,6 +13,8 @@ import collections
 import itertools
 import os
 import math
+import random
 import torch
 from fairseq import distributed_utils, options, progress_bar, tasks, utils
@@ -355,6 +357,12 @@ if __name__ == '__main__':
     elif args.distributed_world_size > 1:
         from multiprocessing_train import main as multiprocessing_main
+        # Set distributed training parameters for a single node.
+        args.distributed_world_size = torch.cuda.device_count()
+        port = random.randint(10000, 20000)
+        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
+        args.distributed_port = port + 1
         multiprocessing_main(args)
     else:
         main(args)
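As a usage illustration, the self-contained sketch below shows how a hypothetical two-node, four-GPU-per-node launch could populate the distributed arguments before each node hands off to multiprocessing_train.main(). The attribute names mirror the options visible in the diff, but the node count, GPU count, host name, and port are assumptions, and nothing here imports or runs fairseq itself.

from argparse import Namespace

NODES = 2            # assumed number of machines
GPUS_PER_NODE = 4    # assumed GPUs per machine

for node in range(NODES):
    args = Namespace(
        # total number of workers across all nodes
        distributed_world_size=NODES * GPUS_PER_NODE,
        # base rank for this node; the spawn loop adds the local GPU index
        distributed_rank=node * GPUS_PER_NODE,
        # shared rendezvous address reachable from every node (hypothetical host/port)
        distributed_init_method='tcp://node0.example.com:12345',
    )
    # On a real cluster each node would now call multiprocessing_train.main(args);
    # here we only print the (device_id, global_rank) pairs that loop would assign.
    workers = [(i, args.distributed_rank + i) for i in range(GPUS_PER_NODE)]
    print('node', node, '->', workers)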