Distributed training using torch.distributed

This version relies on torch.distributed for multi-GPU training and supports both single-node and multi-node configurations.
Sergey Edunov 2018-02-02 14:20:19 -08:00 committed by Myle Ott
parent f656c70379
commit c7033ef794
14 changed files with 745 additions and 531 deletions
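For orientation, a minimal sketch of how a torch.distributed process group is brought up with the TCP init method used throughout this change (host, port, and sizes below are illustrative, not taken from this commit):

import torch.distributed

# every worker calls this once with its own rank; the master host/port are shared by all workers
torch.distributed.init_process_group(
    backend='nccl',                          # matches the new --distributed-backend default
    init_method='tcp://master-host:12345',   # hypothetical master host and port
    world_size=16,                           # e.g. 2 nodes x 8 GPUs
    rank=0,                                  # unique per process, in 0..world_size-1
)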

distributed_train.py Normal file

@@ -0,0 +1,39 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import os
import socket
import subprocess
from train import main as single_process_main, parse_train_args
from fairseq.distributed_utils import distributed_init, supress_output
def main():
args = parse_train_args()
node_list = subprocess.check_output(['scontrol', 'show', 'hostnames',
os.environ.get("SLURM_JOB_NODELIST")])
args.distributed_master_host = node_list.split()[0].decode('utf-8')
if args.distributed_port == -1:
raise ValueError("--distributed-port must be specified for distributed training")
if args.distributed_init_method is None:
args.distributed_init_method = f'tcp://{args.distributed_master_host}:{args.distributed_port + 1}'
args.device_id = int(os.environ.get("SLURM_LOCALID"))
rank = int(os.environ.get("SLURM_NODEID")) * \
int(os.environ.get("SLURM_NTASKS_PER_NODE")) + \
args.device_id
print("Rank: {}, host: {}, local rank {} ".format(rank, socket.gethostname(), args.device_id))
args.distributed_rank = rank
distributed_init(args)
single_process_main(args)
if __name__ == '__main__':
main()
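A worked example of the rank computation above, assuming a hypothetical two-node SLURM job with eight tasks (GPUs) per node:

# hypothetical values for the 4th task on the 2nd node
node_id = 1          # SLURM_NODEID
ntasks_per_node = 8  # SLURM_NTASKS_PER_NODE
local_id = 3         # SLURM_LOCALID, also used as the CUDA device id
rank = node_id * ntasks_per_node + local_id
print(rank)          # 11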

fairseq/data.py

@@ -9,6 +9,7 @@
import contextlib
import itertools
import glob
import math
import numbers
import numpy as np
import os
@@ -130,10 +131,10 @@ class LanguageDatasets(object):
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
def train_dataloader(self, split, num_workers=0, max_tokens=None,
def train_dataloader(self, split, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
seed=None, epoch=1, sample_without_replacement=0,
sort_by_source_size=False):
sort_by_source_size=False, shard_id=0, num_shards=1):
dataset = self.splits[split]
with numpy_seed(seed):
batch_sampler = shuffled_batches_by_size(
@@ -141,38 +142,33 @@ class LanguageDatasets(object):
max_sentences=max_sentences, epoch=epoch,
sample=sample_without_replacement, max_positions=max_positions,
sort_by_source_size=sort_by_source_size)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
dataset, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
def eval_dataloader(self, split, num_workers=0, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False,
descending=False):
descending=False, shard_id=0, num_shards=1):
dataset = self.splits[split]
batch_sampler = batches_by_size(
dataset.src, dataset.dst, max_tokens, max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
def skip_group_enumerator(it, ngpus, offset=0):
res = []
def skip_group_enumerator(it, offset=0):
i = 0
for sample in it:
res.append(sample)
if len(res) >= ngpus:
if i >= offset:
yield res
res = []
i += 1
if len(res) > 0:
yield res
if i >= offset:
yield sample
i += 1
class sharded_iterator(object):
@@ -191,7 +187,7 @@ class sharded_iterator(object):
yield v
class LanguagePairDataset(object):
class LanguagePairDataset(torch.utils.data.Dataset):
# padding constants
LEFT_PAD_SOURCE = True
@@ -221,7 +217,8 @@
@staticmethod
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return LanguagePairDataset.collate_tokens(
[s[key] for s in samples],
@@ -398,6 +395,18 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
return batches
def mask_batches(batch_sampler, shard_id, num_shards):
if num_shards == 1:
return batch_sampler
res = [
batch
for i, batch in enumerate(batch_sampler)
if i % num_shards == shard_id
]
expected_length = int(math.ceil(len(batch_sampler) / num_shards))
return res + [[]] * (expected_length - len(res))
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and

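A small standalone sketch of the sharding rule that mask_batches implements: each worker keeps every num_shards-th batch starting at its shard_id, and short shards are padded with empty batches so all workers take the same number of steps (the helper name and toy batches below are illustrative):

import math

def shard(batches, shard_id, num_shards):
    # same selection and padding rule as mask_batches above
    kept = [b for i, b in enumerate(batches) if i % num_shards == shard_id]
    expected = int(math.ceil(len(batches) / num_shards))
    return kept + [[]] * (expected - len(kept))

batches = [[0, 1], [2], [3, 4], [5], [6, 7]]  # five toy batches, three shards
print(shard(batches, 0, 3))  # [[0, 1], [5]]
print(shard(batches, 1, 3))  # [[2], [6, 7]]
print(shard(batches, 2, 3))  # [[3, 4], []] -- padded to the common length of 2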
fairseq/distributed_utils.py Normal file

@@ -0,0 +1,36 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch.distributed
def distributed_init(args):
if args.distributed_world_size == 1:
pass
print(f'Distributed init {args.distributed_init_method}')
if args.distributed_init_method.startswith("tcp://"):
torch.distributed.init_process_group(backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank)
else:
torch.distributed.init_process_group(backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size)
def supress_output():
import builtins as __builtin__
# Suppress printing for all but the 0th device.
# print(str, force=True) will force it print
_print = __builtin__.print
def print(*args, **kwargs):
if 'force' in kwargs:
force = kwargs.pop('force')
if force:
_print(*args, **kwargs)
__builtin__.print = print
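In effect, once supress_output() has been called on a non-zero rank, ordinary print calls are dropped and only calls passing force=True still appear. A usage sketch, assuming this commit's fairseq package is importable:

from fairseq.distributed_utils import supress_output

supress_output()
print('per-batch logging')        # silently dropped on this worker
print('fatal error', force=True)  # still printed via the saved built-in print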

fairseq/multiprocessing_event_loop.py (deleted)

@@ -1,167 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import os
import signal
import threading
from torch import multiprocessing
class MultiprocessingEventLoop(object):
"""Start a multiprocessing event loop."""
def __init__(self, device_ids=None, multiprocessing_method='spawn'):
super().__init__()
self.device_ids = tuple(device_ids)
self.num_replicas = len(device_ids)
self.rank = None
self._mp = multiprocessing.get_context(multiprocessing_method)
self._start_error_handler()
self._start_multiprocessing()
def call_async(self, rank, action, **kwargs):
"""Asynchronously call a function in each child process.
Call a function named `action` on the rank'th process and return
a Future with the result.
"""
def result_generator():
yield self.return_pipes[rank].recv()
assert not self.return_pipes[rank].poll(), \
'return pipe must be consumed before calling another function'
self.input_pipes[rank].send((action, kwargs))
return Future(result_generator())
def stop(self, interrupt_children=False):
"""Stop multiprocessing."""
for rank in range(self.num_replicas):
self.input_pipes[rank].close()
self.return_pipes[rank].close()
if interrupt_children:
# send KeyboardInterrupt to children
os.kill(self.procs[rank].pid, signal.SIGINT)
else:
self.procs[rank].join()
self.error_queue.put((None, None)) # poison pill
def _start_error_handler(self):
"""Error handler to catch exceptions in child processes."""
# create a thread to listen for errors in the child processes
self.error_queue = self._mp.SimpleQueue()
error_thread = threading.Thread(target=self._error_listener,
daemon=True)
error_thread.start()
# create signal handler that executes in the main process/thread and
# handles errors from child processes
signal.signal(signal.SIGUSR1, self._signal_handler)
def _error_listener(self):
"""A thread that listens for errors in the child processes.
Errors are handled in a signal handler in the main thread.
"""
(rank, original_trace) = self.error_queue.get()
if rank is None: # poison pill, return
return
# requeue error and switch to main thread for handling the error
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)
def _signal_handler(self, signal, frame):
"""Signal handler that handles errors from child processes.
This signal handler executes in the main/process thread.
"""
self.stop(interrupt_children=True)
(rank, original_trace) = self.error_queue.get()
msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
msg += original_trace
raise Exception(msg)
def _start_multiprocessing(self):
"""Create child processes to run async event loop.
Each process reads input from a Pipe, performs some computation,
and returns its output to another Pipe.
"""
# create child processes
input_pipes = []
return_pipes = []
procs = []
for rank, id in enumerate(self.device_ids):
recv_input_pipe, send_input_pipe = self._mp.Pipe(duplex=False)
recv_return_pipe, send_return_pipe = self._mp.Pipe(duplex=False)
proc = self._mp.Process(
target=self._process_event_loop,
args=(rank, id, recv_input_pipe, send_return_pipe),
daemon=True)
proc.start()
input_pipes.append(send_input_pipe)
return_pipes.append(recv_return_pipe)
procs.append(proc)
self.input_pipes = input_pipes
self.return_pipes = return_pipes
self.procs = procs
def _process_event_loop(self, rank, device_id, input_pipe, return_pipe):
"""Event loop that runs in each child process.
Event loop:
- take an action from the input pipe
- call the corresponding function in this process
- put the return value in the return pipe
Any exceptions are put in the error queue.
"""
self.rank = rank
try:
# event loop
while True:
action, kwargs = input_pipe.recv()
action_fn = getattr(self, action)
return_pipe.send(action_fn(rank, device_id, **kwargs))
except EOFError:
# input pipe was closed, do nothing
pass
except KeyboardInterrupt:
# killed by parent, do nothing
pass
except Exception:
# propagate exception from child to parent process, keeping
# original traceback
import traceback
self.error_queue.put((rank, traceback.format_exc()))
finally:
# cleanup pipes
input_pipe.close()
return_pipe.close()
class Future(object):
"""A wrapper around a Python generator, with syntactic sugar."""
def __init__(self, generator):
self.generator = generator
def gen(self):
return next(self.generator)
@staticmethod
def gen_list(gens):
return [g.gen() for g in gens]
@staticmethod
def gen_tuple_list(gens):
list = [g.gen() for g in gens]
return zip(*list)

fairseq/nccl.py (deleted)

@@ -1,181 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
"""
A modified version of torch.cuda.nccl.all_reduce for launching kernels on each
GPU separately.
"""
import ctypes
from ctypes.util import find_library
lib = None
nccl_2_0 = None
_uid = None
_rank = None
_num_devices = None
_comm = None
__all__ = ['all_reduce', 'initialize', 'get_unique_id']
# ncclDataType_t
nccl_types = {
'torch.cuda.ByteTensor': 0,
'torch.cuda.CharTensor': 0,
'torch.cuda.IntTensor': 1,
'torch.cuda.HalfTensor': 2,
'torch.cuda.FloatTensor': 3,
'torch.cuda.DoubleTensor': 4,
'torch.cuda.LongTensor': 5,
}
nccl_types_2_0 = {
'torch.cuda.ByteTensor': 0,
'torch.cuda.CharTensor': 0,
'torch.cuda.IntTensor': 2,
'torch.cuda.HalfTensor': 6,
'torch.cuda.FloatTensor': 7,
'torch.cuda.DoubleTensor': 8,
'torch.cuda.LongTensor': 4,
}
# ncclRedOp_t
SUM = 0
PROD = 1
MAX = 2
MIN = 3
status_codes_2_0 = {
0: "Success",
1: "Unhandled Cuda Error",
2: "System Error",
3: "Internal Error",
4: "Invalid Argument Error",
5: "Invalid Usage Error",
}
status_codes = {
0: "Success",
1: "Unhandled Cuda Error",
2: "System Error",
3: "Internal Error",
4: "Invalid Device Pointer",
5: "Invalid Rank",
6: "Unsupported Device Count",
7: "Device Not Found",
8: "Invalid Device Index",
9: "Lib Wrapper Not Set",
10: "Cuda Malloc Failed",
11: "Rank Mismatch",
12: "Invalid Argument",
13: "Invalid Type",
14: "Invalid Operation",
}
def _libnccl():
global nccl_2_0
global lib
global status_codes
global nccl_types
if lib is None:
lib = ctypes.pydll.LoadLibrary(find_library('nccl'))
if hasattr(lib, 'ncclCommDestroy'):
lib.ncclCommDestroy.restype = None
else:
lib = None
if hasattr(lib, 'ncclGroupStart'):
nccl_2_0 = True
status_codes = status_codes_2_0
nccl_types = nccl_types_2_0
return lib
class NcclError(RuntimeError):
def __init__(self, status):
self.status = status
msg = '{0} ({1})'.format(status_codes.get(status), status)
super(NcclError, self).__init__(msg)
class NcclComm(ctypes.c_void_p):
def __del__(self):
lib.ncclCommDestroy(self)
class NcclUniqueId(ctypes.Structure):
_fields_ = [
('internal', ctypes.c_uint8 * 128)
]
def check_error(status):
if status != 0:
raise NcclError(status)
_uids = []
def get_unique_id():
if _libnccl() is None:
raise RuntimeError('Unable to load NCCL library')
uid = NcclUniqueId()
check_error(lib.ncclGetUniqueId(ctypes.byref(uid)))
_uids.append(uid) # Don't allow UIDs to be collected
return uid
def initialize(num_devices, uid, rank):
global _num_devices, _uid, _rank
if _libnccl() is None:
raise RuntimeError('Unable to load NCCL library')
_num_devices = num_devices
if rank != 0:
_uid = NcclUniqueId.from_buffer_copy(uid)
else:
_uid = uid
_rank = rank
def communicator():
global _comm
if _libnccl() is None:
raise RuntimeError('Unable to load NCCL library')
if _uid is None:
raise RuntimeError('NCCL not initialized')
if _comm is None:
comm = NcclComm()
check_error(lib.ncclCommInitRank(
ctypes.byref(comm),
ctypes.c_int(_num_devices),
_uid,
ctypes.c_int(_rank)))
_comm = comm
return _comm
def all_reduce(input, output=None, op=SUM, stream=None):
comm = communicator()
if output is None:
output = input
if stream is not None:
stream = stream.cuda_stream
data_type = nccl_types[input.type()]
check_error(lib.ncclAllReduce(
ctypes.c_void_p(input.data_ptr()),
ctypes.c_void_p(output.data_ptr()),
ctypes.c_size_t(input.numel()),
data_type,
op,
comm,
ctypes.c_void_p(stream)))
return output

fairseq/options.py

@@ -7,6 +7,7 @@
#
import argparse
import torch
import torch.cuda
@@ -69,8 +70,6 @@ def add_dataset_args(parser, train=False, gen=False):
help='source language')
group.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
group.add_argument('-j', '--workers', default=1, type=int, metavar='N',
help='number of data loading workers (default: 1)')
group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
@@ -100,6 +99,25 @@ def add_dataset_args(parser, train=False, gen=False):
return group
def add_distributed_training_args(parser):
group = parser.add_argument_group('Multi-GPU training')
group.add_argument('--distributed-world-size', default=1, type=int, metavar='N',
help='total number of GPUs across all nodes, default: 1 GPU')
group.add_argument('--distributed-master-host', default='localhost', type=str,
help='Master host used for synchronizing stats across nodes')
group.add_argument('--distributed-port', default=-1, type=int,
help='TCP port number for synchronizing stats across nodes')
group.add_argument('--distributed-rank', default=0, type=int,
help='rank of the current worker')
group.add_argument('--distributed-backend', default='nccl', type=str,
help='distributed backend')
group.add_argument('--distributed-init-method', default=None, type=str,
help='Typically tcp://hostname:port that will be used to '
'establish initial connection')
return group
def add_optimization_args(parser):
group = parser.add_argument_group('Optimization')
group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',

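For a hypothetical two-node job with eight GPUs per node, the new flags would leave the parsed args looking roughly like this (every value below is illustrative, not from this commit):

args.distributed_world_size = 16                     # 2 nodes x 8 GPUs
args.distributed_master_host = 'node-0'              # first host in the SLURM node list
args.distributed_port = 15000
args.distributed_backend = 'nccl'
args.distributed_init_method = 'tcp://node-0:15001'  # port + 1, as set in distributed_train.py
args.distributed_rank = 11                           # NODEID * NTASKS_PER_NODE + LOCALID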
fairseq/sequence_generator.py

@@ -52,27 +52,27 @@ class SequenceGenerator(object):
return self
def generate_batched_itr(
self,
data_itr,
beam_size=None,
maxlen_a=0.0,
maxlen_b=None,
cuda_device=None,
timer=None,
self,
data_itr,
beam_size=None,
maxlen_a=0.0,
maxlen_b=None,
cuda=False,
timer=None,
):
"""Iterate over a batched dataset and yield individual translations.
Args:
maxlen_a/b: generate sequences of maximum length ax + b,
where x is the source sentence length.
cuda_device: GPU on which to do generation.
cuda: use GPU for generation
timer: StopwatchMeter for timing generations.
"""
if maxlen_b is None:
maxlen_b = self.maxlen
for sample in data_itr:
s = utils.make_variable(sample, volatile=True, cuda_device=cuda_device)
s = utils.make_variable(sample, volatile=True, cuda=cuda)
input = s['net_input']
srclen = input['src_tokens'].size(1)
if timer is not None:

fairseq/tcp_connector.py Normal file

@@ -0,0 +1,127 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import atexit
import pickle
import socket
import socketserver
import threading
import time
import faulthandler
class TcpConnector(object):
"""Synchronize data across multiple nodes over TCP."""
SOCKET_TIMEOUT=300
MESSAGE_QUEUE_SIZE = 10
def __init__(self, port, rank, world_size, master_host):
self.port = port
self.root = master_host
self.nhosts = world_size
self.host_idx = rank
self.server = None
self.current_message_id = 0
self.socket = None
faulthandler.enable(all_threads=True)
if rank == 0:
self._create_server()
def _create_server(self):
messages = {}
condition = threading.Condition()
nhosts = self.nhosts
class TCPHandler(socketserver.BaseRequestHandler):
def handle(self):
# Keep socket open forever
while True:
rcvd = TcpConnector.recv_msg(self.request)
if rcvd is None:
return
id, host_idx, data = rcvd
with condition:
if not id in messages:
messages[id] = [None] * nhosts
k = id - TcpConnector.MESSAGE_QUEUE_SIZE
if k in messages:
del messages[k]
messages[id][host_idx] = data
condition.wait_for(lambda : sum(1 for x in messages[id] if x is None) == 0,
timeout=TcpConnector.SOCKET_TIMEOUT)
condition.notify_all()
TcpConnector.send_msg(self.request, messages[id])
# HOST='' means running on interface 0.0.0.0 that seems to work for everyone
self.server = socketserver.ThreadingTCPServer(('', self.port), TCPHandler)
self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.server_thread.start()
print("Server is running on {}:{}".format(socket.gethostname(), self.port), flush=True)
@atexit.register
def _cleanup():
self.server.shutdown()
self.server.server_close()
self.server_thread.join()
def all_gather(self, message):
"""Gathers messages from all nodes into a list."""
for retry in range(8):
if retry > 0:
print("Retry {}, message {}".format(retry, message), flush=True)
time.sleep(2 ** retry)
try:
if self.socket is None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.settimeout(TcpConnector.SOCKET_TIMEOUT)
self.socket.connect((self.root, self.port))
TcpConnector.send_msg(self.socket, (self.current_message_id, self.host_idx, message))
received = TcpConnector.recv_msg(self.socket)
self.current_message_id += 1
return received
except socket.timeout:
print("Socket timeout", flush=True)
TcpConnector.close(self.socket)
except ConnectionError:
print("Unable to connect to {}:{}, message_id {}, host_idx {}, message {}".format(
self.root, self.port, self.current_message_id, self.host_idx, message), flush=True)
TcpConnector.close(self.socket)
except Exception as e:
print("Unexpected exception {}".format(e))
break
raise Exception("Unable to send the message to the root node")
@staticmethod
def close(socket):
if socket:
try:
socket.close()
except:
print("Unable to close socket")
@staticmethod
def send_msg(stream, message):
enc = pickle.dumps(message)
stream.sendall(len(enc).to_bytes(8, byteorder='big'))
stream.sendall(enc)
@staticmethod
def recv_msg(stream):
size = int.from_bytes(stream.recv(8), byteorder='big')
if size == 0:
print('Shutdown request received', flush=True)
return None
enc = stream.recv(size)
while len(enc) < size:
enc += stream.recv(size - len(enc))
data = pickle.loads(enc)
return data
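The send_msg/recv_msg pair defines a simple length-prefixed pickle framing: an 8-byte big-endian length followed by the pickled payload, read in a loop until complete. A self-contained sketch of that framing over a local socket pair (socket.socketpair is used here purely for illustration; the real class connects to the master host):

import pickle
import socket

a, b = socket.socketpair()
payload = (0, 1, {'loss': 2.5})                   # (message_id, host_idx, data)
enc = pickle.dumps(payload)
a.sendall(len(enc).to_bytes(8, byteorder='big'))  # 8-byte big-endian length prefix
a.sendall(enc)

size = int.from_bytes(b.recv(8), byteorder='big')
buf = b.recv(size)
while len(buf) < size:                            # keep reading until the full message arrives
    buf += b.recv(size - len(buf))
print(pickle.loads(buf))                          # (0, 1, {'loss': 2.5})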

fairseq/multiprocessing_trainer.py → fairseq/trainer.py

@@ -10,16 +10,17 @@
Train a network on multiple GPUs using multiprocessing.
"""
from itertools import cycle, islice
import math
import torch
import torch.distributed
from fairseq import optim, nccl, utils
from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
from fairseq import optim
from fairseq.optim import lr_scheduler
from fairseq import utils
from fairseq.tcp_connector import TcpConnector
class MultiprocessingTrainer(MultiprocessingEventLoop):
class Trainer(object):
"""Main class for multi-GPU training.
Each GPU has a full copy of the model and is assigned to its own Python
@@ -31,35 +32,18 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
(prefixed with `_async_`), which run on each process in parallel.
"""
def __init__(self, args, model, criterion, device_ids=None,
multiprocessing_method='spawn'):
if device_ids is None:
device_ids = tuple(range(torch.cuda.device_count()))
super().__init__(device_ids, multiprocessing_method)
def __init__(self, args, model, criterion):
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
model = model.share_memory()
nccl_uid = nccl.get_unique_id()
self.criterion = criterion
Future.gen_list([
self.call_async(rank, '_async_init', args=args, model=model,
criterion=criterion, nccl_uid=nccl_uid)
for rank in range(self.num_replicas)
])
self._grads_initialized = False
def _async_init(self, rank, device_id, args, model, criterion, nccl_uid):
"""Initialize child processes."""
self.args = args
self.rank = args.distributed_rank
self.world_size = args.distributed_world_size
# set CUDA device
torch.cuda.set_device(device_id)
# initialize NCCL
nccl.initialize(self.num_replicas, nccl_uid, device_id)
self._init_tcp_connector(args)
# copy model and criterion to current device
self.model = model.cuda()
@@ -73,34 +57,26 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self._max_bsz_seen = 0
self._num_updates = 0
def _init_tcp_connector(self, args):
"""Discover rank of current host and hostnames of all other hosts."""
if args.distributed_world_size > 1:
self.tcp_connector = TcpConnector(
args.distributed_port, self.rank, self.world_size, args.distributed_master_host)
def get_model(self):
"""Get one of the model replicas."""
# just return the first model, since all replicas are the same
return self.call_async(0, '_async_get_model').gen()
def _async_get_model(self, rank, device_id):
return self.model
def save_checkpoint(self, filename, extra_state):
"""Save a checkpoint for the current model."""
self.call_async(0, '_async_save_checkpoint', filename=filename, extra_state=extra_state).gen()
def _async_save_checkpoint(self, rank, device_id, filename, extra_state):
utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state)
if self.rank == 0: # only save one checkpoint
utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state)
def load_checkpoint(self, filename):
"""Load a checkpoint into the model replicas in each process."""
results = Future.gen_list([
self.call_async(rank, '_async_load_checkpoint', filename=filename)
for rank in range(self.num_replicas)
])
extra_state = results[0]
return extra_state
def _async_load_checkpoint(self, rank, device_id, filename):
extra_state, self._optim_history, last_optim_state = utils.load_model_state(
filename, self.model, cuda_device=device_id)
filename, self.model, cuda_device=torch.cuda.current_device())
if last_optim_state is not None:
# rebuild optimizer after loading model, since params may have changed
@@ -119,48 +95,38 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
return extra_state
def set_seed(self, seed):
Future.gen_list([
self.call_async(rank, '_async_set_seed', seed=seed)
for rank in range(self.num_replicas)
])
def _async_set_seed(self, rank, device_id, seed):
torch.manual_seed(seed)
def train_step(self, samples):
def train_step(self, sample):
"""Do forward, backward and gradient step in parallel."""
# PyTorch initializes gradient buffers lazily, so the first
# train step needs to send non-empty samples to all replicas
replace_empty_samples = False
if not self._grads_initialized:
replace_empty_samples = True
self._grads_initialized = True
# scatter sample across GPUs
self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
self.prepare_sample(sample, volatile=False)
# forward pass
sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
self.call_async(rank, '_async_forward')
for rank in range(self.num_replicas)
])
sample_sizes, logging_outputs, ooms_fwd = self.forward()
if self.world_size > 1:
# synchronize logging outputs for multi-node training
sample_sizes, logging_outputs = zip(*list(self.tcp_connector.all_gather((sample_sizes, logging_outputs))))
else:
sample_sizes = [sample_sizes]
logging_outputs = [logging_outputs]
# backward pass, all-reduce gradients and take an optimization step
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
grad_norms, ooms_bwd, lrs = Future.gen_tuple_list([
self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
for rank in range(self.num_replicas)
])
grad_norm, ooms_bwd, lr = self.backward_and_opt(grad_denom=grad_denom)
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
logging_output['lr'] = lrs[0]
logging_output['gnorm'] = grad_norms[0] # log the gradient norm
logging_output['oom'] = sum(ooms_fwd) + sum(ooms_bwd)
logging_output['lr'] = lr
logging_output['gnorm'] = grad_norm # log the gradient norm
logging_output['oom'] = ooms_fwd + ooms_bwd
logging_output['ntokens'] = sum(log.get('ntokens', 0) for log in logging_outputs)
logging_output['nsentences'] = sum(log.get('nsentences', 0) for log in logging_outputs)
return logging_output
def _async_forward(self, rank, device_id, eval=False):
def forward(self, eval=False):
if eval:
self.model.eval()
else:
@@ -168,14 +134,18 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self.optimizer.zero_grad()
with utils.maybe_no_grad(eval):
sample_size, logging_output, oom = 0, {}, False
self.sample_size = 0
logging_output = {'ntokens': 0, 'nsentences': 0}
oom = False
if self._sample is not None:
try:
# calculate loss and sample size
self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
self.loss, self.sample_size, logging_output = self.criterion(self.model, self._sample)
logging_output['ntokens'] = self._sample['ntokens']
logging_output['nsentences'] = self._sample['target'].size(0)
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
print('| WARNING: ran out of memory, skipping batch')
oom = True
self.loss = None
if hasattr(torch.cuda, 'empty_cache'):
@@ -183,9 +153,9 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
else:
raise e
return sample_size, logging_output, oom
return self.sample_size, logging_output, oom
def _async_backward_and_opt(self, rank, device_id, grad_denom):
def backward_and_opt(self, grad_denom):
oom = False
if self.loss is not None:
try:
@@ -193,7 +163,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self.loss.backward()
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
print('| WARNING: ran out of memory, skipping batch')
oom = True
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
@@ -201,8 +171,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
else:
raise e
# all-reduce grads and rescale by grad_denom
self._all_reduce_and_rescale_grads(grad_denom)
if self.world_size > 1:
# all-reduce grads and rescale by grad_denom
self._all_reduce_and_rescale_grads(grad_denom)
else:
for p in self.model.parameters():
if p.requires_grad:
p.grad.data.div_(grad_denom)
# clip grads
if self.args.clip_norm > 0:
@@ -236,8 +211,9 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
buffer_t[offset:offset+numel].copy_(g.view(-1))
offset += numel
# all-reduce and rescale
nccl.all_reduce(buffer_t[:offset])
torch.distributed.all_reduce(buffer_t[:offset])
buffer_t.div_(grad_denom)
# copy all-reduced buffer back into grads
offset = 0
for g in buffer:
@@ -250,7 +226,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
sz = g.numel() * g.element_size()
if sz > buffer_size:
# grad is bigger than buffer, all-reduce and rescale directly
nccl.all_reduce(g)
torch.distributed.all_reduce(g)
g.div_(grad_denom)
elif filled + sz > buffer_size:
# buffer is full, all-reduce and replace buffer with grad
@@ -264,64 +240,43 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
if len(buffer) > 0:
all_reduce_buffer()
def valid_step(self, samples):
def valid_step(self, sample):
"""Do forward pass in parallel."""
# scatter sample across GPUs
self._scatter_samples(samples, volatile=True)
self.prepare_sample(sample, volatile=True)
# forward pass
_sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
self.call_async(rank, '_async_forward', eval=True)
for rank in range(self.num_replicas)
])
assert sum(ooms_fwd) == 0
_sample_sizes, logging_outputs, ooms_fwd = self.forward(eval=True)
assert not ooms_fwd
if self.world_size > 1:
logging_outputs = list(self.tcp_connector.all_gather(logging_outputs))
else:
logging_outputs = [logging_outputs]
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
logging_output['ntokens'] = sum(log.get('ntokens', 0) for log in logging_outputs)
logging_output['nsentences'] = sum(log.get('nsentences', 0) for log in logging_outputs)
return logging_output
def get_lr(self):
"""Get the current learning rate."""
return self.call_async(0, '_async_get_lr').gen()
def _async_get_lr(self, rank, device_id):
return self.optimizer.get_lr()
def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate based on the validation loss."""
lr = Future.gen_list([
self.call_async(rank, '_async_lr_step', epoch=epoch, val_loss=val_loss)
for rank in range(self.num_replicas)
])
return lr[0]
def _async_lr_step(self, rank, device_id, epoch, val_loss):
return self.lr_scheduler.step(epoch, val_loss)
def get_num_updates(self):
"""Get the number of parameters updates."""
return self.call_async(0, '_async_get_num_updates').gen()
def _async_get_num_updates(self, rank, device_id):
return self._num_updates
def _scatter_samples(self, samples, volatile=False, replace_empty_samples=False):
"""Split and distribute a sample across GPUs."""
if not replace_empty_samples:
# pad with None until its size is equal to the number of replicas
samples = samples + [None]*(self.num_replicas - len(samples))
else:
# pad by cycling through the given samples
samples = list(islice(cycle(samples), self.num_replicas))
Future.gen_list([
self.call_async(rank, '_async_prepare_sample', sample=samples[rank], volatile=volatile)
for rank in range(self.num_replicas)
])
def _async_prepare_sample(self, rank, device_id, sample, volatile):
if sample is None:
def prepare_sample(self, sample, volatile):
if sample is None or len(sample) == 0:
self._sample = None
else:
if hasattr(torch.cuda, 'empty_cache'):
@@ -330,4 +285,5 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self._max_bsz_seen = sample['target'].size(0)
torch.cuda.empty_cache()
self._sample = utils.make_variable(sample, volatile=volatile, cuda_device=device_id)
self._sample = utils.make_variable(sample, volatile=volatile, cuda=True)
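The buffered all-reduce above packs small gradients into one flat tensor before calling torch.distributed.all_reduce, which cuts down the number of collective calls; an unbuffered per-parameter equivalent, shown only as a sketch (it assumes an already-initialized process group and a criterion-provided grad_denom):

import torch.distributed

def all_reduce_grads(model, grad_denom):
    # per-parameter version of _all_reduce_and_rescale_grads: sum each gradient
    # across workers, then rescale by the shared denominator
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            torch.distributed.all_reduce(p.grad.data)
            p.grad.data.div_(grad_denom)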

fairseq/utils.py

@@ -172,13 +172,16 @@ def volatile_variable(*args, **kwargs):
return Variable(*args, **kwargs, volatile=True)
def make_variable(sample, volatile=False, cuda_device=None):
def make_variable(sample, volatile=False, cuda=False):
"""Wrap input tensors in Variable class."""
if len(sample) == 0:
return {}
def _make_variable(maybe_tensor):
if torch.is_tensor(maybe_tensor):
if cuda_device is not None and torch.cuda.is_available():
maybe_tensor = maybe_tensor.cuda(async=True, device=cuda_device)
if cuda is not None and torch.cuda.is_available():
maybe_tensor = maybe_tensor.cuda(async=True)
if volatile:
return volatile_variable(maybe_tensor)
else:

generate.py

@@ -75,7 +75,7 @@ def main():
gen_timer = StopwatchMeter()
translations = translator.generate_batched_itr(
t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
cuda_device=0 if use_cuda else None, timer=gen_timer)
cuda=use_cuda, timer=gen_timer)
for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth
target_tokens = target_tokens.int().cpu()

multiprocessing_train.py Normal file

@@ -0,0 +1,48 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import random
import torch
from torch import multiprocessing
from train import main as single_process_main, parse_train_args
from fairseq.distributed_utils import distributed_init, supress_output
def main():
args = parse_train_args()
if args.distributed_port == -1:
args.distributed_port = random.randint(10000, 20000)
args.distributed_world_size = torch.cuda.device_count()
args.distributed_master_host = 'localhost'
args.distributed_init_method = f'tcp://{args.distributed_master_host}:{args.distributed_port + 1}'
mp = multiprocessing.get_context("spawn")
procs = []
for i in range(args.distributed_world_size):
args.device_id = i
args.distributed_rank = i
procs.append(mp.Process(target=run, args=(args, )))
procs[i].start()
for p in procs:
p.join()
print(f'Process {p} complete')
def run(args):
distributed_init(args)
if args.device_id != 0:
supress_output()
single_process_main(args)
if __name__ == '__main__':
main()

sweep.py Executable file

@@ -0,0 +1,309 @@
#!/usr/bin/env python
import argparse
from collections import OrderedDict
import glob
import itertools
import os
import random
import shlex
import shutil
import subprocess
import uuid
def get_grid(args):
return [
hyperparam('--force-anneal', 50, save_dir_key=lambda val: f'fa{val}'),
hyperparam('--lr-scheduler', 'fixed'),
hyperparam('--max-epoch', 50),
hyperparam('--arch', 'fconv_wmt_en_de', save_dir_key=lambda val: val.split('_')[0]),
hyperparam('--optimizer', 'nag', save_dir_key=lambda val: val),
hyperparam('--lr', [2.5, 3.5, 5.0], save_dir_key=lambda val: f'lr{val}'),
hyperparam('--max-tokens', [5000], save_dir_key=lambda val: f'maxtok{val}'),
#hyperparam('--batch-size-warmup-epochs', [4], save_dir_key=lambda val: f'bszwarmup{val}'),
hyperparam('--clip-norm', 0.1, save_dir_key=lambda val: f'clip{val}'),
hyperparam('--dropout', 0.1, save_dir_key=lambda val: f'drop{val}'),
#hyperparam('--label-smoothing', 0.1, save_dir_key=lambda val: f'ls{val}'),
hyperparam('--log-format', 'json'),
hyperparam('--log-interval', '100'),
]
def postprocess_hyperparams(args, config):
"""Postprocess a given hyperparameter configuration."""
#if config['--seq-beam'].current_value <= 8:
# config['--max-tokens'].current_value = 400
#else:
# config['--max-tokens'].current_value = 300
pass
class hyperparam(object):
"""Base class for defining hyperparameters."""
def __init__(self, name, values=None, binary_flag=False, save_dir_key=None):
"""
Arguments:
- name : the name of the hyperparameter (e.g., `--dropout`)
- values : the set of values to sweep over (e.g., `[0.0, 0.1, 0.2]`)
- binary_flag : whether the hyperparameter uses a boolean flag (e.g., `--no-save`)
- save_dir_key : function that takes the hyperparameter value and returns the "key"
to be appended to the output directory name
"""
self.name = name
if values is None: # syntactic sugar for binary flags
self.values = [True]
self.binary_flag = True
else:
self.values = values if isinstance(values, list) else [values]
self.binary_flag = binary_flag
self.save_dir_key = save_dir_key
self.current_value = None
if len(self.values) > 1 and self.save_dir_key is None:
raise ValueError(f'{name} has more than one value but is missing a save_dir_key!')
def get_cli_args(self):
if self.binary_flag:
return [self.name] if self.current_value else []
else:
return [self.name, self.current_value]
def get_save_dir_key(self):
if self.save_dir_key is None:
return None
if self.binary_flag:
return self.save_dir_key(1) if self.current_value else None
return self.save_dir_key(self.current_value)
def main():
parser = argparse.ArgumentParser('Script for launching hyperparameter sweeps')
parser.add_argument('-d', '--data', required=True, help='path to data directory')
parser.add_argument('-p', '--prefix', required=True,
help='save checkpoints and logs in <checkpoints-dir>/<prefix>.<save_dir_key>')
parser.add_argument('-t', '--num-trials', required=True, type=int,
help='number of random hyperparam configurations to try (-1 for grid search)')
parser.add_argument('-g', '--num-gpus', type=int, required=True, help='number of GPUs per node')
parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes for distributed training')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--baseline-model', help='path to baseline model from which to resume training')
parser.add_argument('--checkpoints-dir', default=os.path.join('/checkpoint', os.environ['USER']),
help='save checkpoints and logs in <checkpoints-dir>/<prefix>.<save_dir_key>')
parser.add_argument('--resume-failed', action='store_true',
help='resume any runs that failed (assumes --num-trials and --seed are the same)')
parser.add_argument('--resume-finished', action='store_true',
help='force any runs that finished to begin again (uncommon)')
parser.add_argument('--dry-run', action='store_true',
help='output only a list of actions to perform without performing them')
parser.add_argument('--local', action='store_true',
help='run locally instead of submitting remote job')
args = parser.parse_args()
# compute all possible hyperparameter configurations
grid = get_grid(args)
grid_product = list(itertools.product(*[hp.values for hp in grid]))
# randomly shuffle configurations
random.seed(args.seed)
random.shuffle(grid_product)
for i, hp_values in enumerate(grid_product):
config = OrderedDict()
for hp, value in zip(grid, hp_values):
config[hp.name] = hp
config[hp.name].current_value = value
# postprocess hyperparams
postprocess_hyperparams(args, config)
# launch training
launch_train(args, config)
if i == args.num_trials - 1:
break
def launch_train(args, config):
def dry_run(msg):
if args.dry_run:
print(f'| dry-run: {msg}')
return args.dry_run
# compute save_dir
save_dir_key = '.'.join(filter(
lambda save_dir_key: save_dir_key is not None,
[hp.get_save_dir_key() for hp in config.values()]
))
num_total_gpus = args.num_nodes * args.num_gpus
save_dir = os.path.join(args.checkpoints_dir, f'{args.prefix}.{save_dir_key}.ngpu{num_total_gpus}')
# create save directory if it doesn't exist
if not os.path.exists(save_dir):
if not dry_run(f'create directory: {save_dir}'):
os.makedirs(save_dir)
# copy baseline model
checkpoint_last = os.path.join(save_dir, 'checkpoint_last.pt')
if args.baseline_model and not os.path.exists(checkpoint_last) and \
not dry_run(f'initialize with baseline model: {args.baseline_model}'):
if not os.path.exists(args.baseline_model):
raise FileNotFoundError(f'Cannot find baseline model: {args.baseline_model}')
shutil.copyfile(args.baseline_model, checkpoint_last)
# check for whether the run failed
if has_finished(save_dir, num_total_gpus):
if args.resume_finished:
dry_run(f'restart previously finished run: {save_dir}')
else:
print(f'skip finished run (override with --resume-finished): {save_dir}')
return
elif has_failed(save_dir, num_total_gpus):
if args.resume_failed:
dry_run(f'resume failed run: {save_dir}')
else:
print(f'skip failed run (override with --resume-failed): {save_dir}')
return
elif has_started(save_dir, num_total_gpus):
print(f'skip in progress run: {save_dir}')
return
# generate train command
train_cmd = ['python', 'distributed_train.py', args.data, '--save-dir', save_dir]
if num_total_gpus > 1:
train_cmd.extend(['--distributed-world-size', str(num_total_gpus)])
train_cmd.extend(['--distributed-port', str(get_random_port())])
for hp in config.values():
train_cmd.extend(map(str, hp.get_cli_args()))
if args.dry_run:
train_cmd_str = ' '.join(train_cmd)
dry_run(f'train command: {train_cmd_str}')
# start training
if args.local:
assert args.num_nodes == 1, 'distributed training cannot be combined with --local'
if not dry_run('start training locally'):
env = os.environ.copy()
env['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, range(args.num_gpus)))
train_proc = subprocess.Popen(train_cmd, env=env)
train_proc.wait()
else:
if num_total_gpus == 1:
train_log = os.path.join(save_dir, 'train.log')
train_stderr = os.path.join(save_dir, 'train.stderr.%j') # %j = slurm job id
else:
# %t = slurm task id, %n = slurm node id
# If one log file per node is used, output has to be de-duplicated,
# e.g. by suppressing output of all but one task
# see fairseq.distributed_utils.supress_output()
train_log = os.path.join(save_dir, 'train.log.node%t')
train_stderr = os.path.join(save_dir, 'train.stderr.node%t.%j')
# build command
excluded_hosts = os.environ.get('EXCLUDED_HOSTS', None)
srun_cmd = [
'srun',
'--output', train_log,
'--error', train_stderr,
] + train_cmd
sbatch_cmd = [
'sbatch',
'--job-name', f'{args.prefix}.{save_dir_key}',
'--gres', f'gpu:{args.num_gpus}',
'--nodes', str(args.num_nodes),
'--ntasks-per-node', str(args.num_gpus),
'--cpus-per-task', str(10),
'--open-mode', 'append',
'--no-requeue',
]
sbatch_cmd += ['-x', excluded_hosts] if excluded_hosts is not None else []
sbatch_cmd += ['--wrap', ' '.join(map(shlex.quote, srun_cmd))]
sbatch_cmd_str = ' '.join(map(shlex.quote, sbatch_cmd))
if args.dry_run:
dry_run('start remote training')
dry_run(f'- log stdout to: {train_log}')
dry_run(f'- log stderr to: {train_stderr}')
dry_run(f'- run command: {sbatch_cmd_str}')
else:
for i in range(args.num_nodes):
with open(train_log.replace('%t', str(i)), 'a') as train_log_h:
# log most recent git commit
git_commit = subprocess.check_output(
'git log | head -n 1', shell=True, encoding='utf-8')
print(git_commit.rstrip(), file=train_log_h)
if i == 0:
print(f'running command: {sbatch_cmd_str}\n')
train_proc = subprocess.Popen(sbatch_cmd, stdout=train_log_h)
def has_finished(save_dir, num_nodes):
if num_nodes == 1:
train_log = os.path.join(save_dir, 'train.log')
else:
train_log = os.path.join(save_dir, 'train.log.node0')
if not os.path.exists(train_log):
return False
with open(train_log, 'r') as h:
lines = h.readlines()
if len(lines) == 0:
return False
if 'done training' in lines[-1]:
return True
return False
def has_failed(save_dir, num_nodes):
if not os.path.exists(save_dir):
return False
# find max job id
job_ids = []
for fn in os.listdir(save_dir):
if fn.startswith('train.stderr.'):
job_ids.append(int(fn.split('.')[-1]))
if len(job_ids) == 0:
return False
max_job_id = max(job_ids)
def _has_failed(stderr_fn):
with open(stderr_fn, 'r') as h:
for line in h:
if len(line.strip()) > 0:
# assume that any output in stderr indicates an error
return True
return False
if num_nodes == 1:
return _has_failed(os.path.join(save_dir, f'train.stderr.{max_job_id}'))
else:
for fn in os.listdir(save_dir):
if fn.startswith('train.stderr.') and fn.endswith(f'.{max_job_id}') \
and _has_failed(os.path.join(save_dir, fn)):
return True
return False
def has_started(save_dir, num_nodes):
if num_nodes == 1:
train_log = os.path.join(save_dir, 'train.log')
else:
train_log = os.path.join(save_dir, 'train.log.node0')
if not os.path.exists(train_log):
return False
return True
def get_random_port():
rng_state = random.getstate()
random.seed()
port = random.randint(8082, 20000)
random.setstate(rng_state)
return port
if __name__ == '__main__':
main()
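The sweep expands the hyperparameter grid by taking the Cartesian product of every hyperparam's values and turning each combination into CLI flags plus a save-directory key; a toy illustration of that expansion with made-up flags:

import itertools

grid = {'--lr': [2.5, 3.5, 5.0], '--dropout': [0.1]}  # miniature made-up grid
for combo in itertools.product(*grid.values()):
    flags = [s for name, val in zip(grid, combo) for s in (name, str(val))]
    print(flags)
# ['--lr', '2.5', '--dropout', '0.1']
# ['--lr', '3.5', '--dropout', '0.1']
# ['--lr', '5.0', '--dropout', '0.1']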

train.py

@@ -1,4 +1,4 @@
#!/usr/bin/env python3
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
@@ -9,31 +9,31 @@
import collections
import os
import sys
import torch
import math
import torch.distributed
import torch.cuda
from fairseq import criterions, data, models, options, progress_bar
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.multiprocessing_trainer import MultiprocessingTrainer
from fairseq.trainer import Trainer
def main():
parser = options.get_parser('Trainer')
options.add_dataset_args(parser, train=True)
options.add_optimization_args(parser)
options.add_checkpoint_args(parser)
options.add_model_args(parser)
args = options.parse_args_and_arch(parser)
print(args)
def main(args=None):
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
if not args:
args = parse_train_args()
args.device_id = 0
os.makedirs(args.save_dir, exist_ok=True)
if args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences
if args.num_gpus == 0:
raise NotImplementedError('Training on CPU is not supported')
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
torch.manual_seed(args.seed)
# Load dataset
@@ -46,14 +46,13 @@ def main():
# record inferred languages in args, so that it's saved in checkpoints
args.source_lang, args.target_lang = dataset.src, dataset.dst
print(args)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
for split in splits:
print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
args.num_gpus, args.max_tokens, args.max_sentences))
torch.cuda.set_device(args.device_id)
# Build model and criterion
model = models.build_model(args, dataset.src_dict, dataset.dst_dict)
criterion = criterions.build_criterion(args, dataset.src_dict, dataset.dst_dict)
@@ -68,8 +67,12 @@ def main():
)
max_positions_valid = (model.max_encoder_positions(), model.max_decoder_positions())
gpus_str = '{} GPUs'.format(args.distributed_world_size)
print('| using {} (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
gpus_str, args.max_tokens, args.max_sentences))
# Start multiprocessing
trainer = MultiprocessingTrainer(args, model, criterion)
trainer = Trainer(args, model, criterion)
# Load the latest checkpoint if one is available
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
@@ -79,13 +82,12 @@ def main():
batch_offset = extra_state['batch_offset']
print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
if batch_offset == 0:
lr = trainer.lr_step(epoch)
trainer.lr_step(epoch)
epoch += 1
else:
epoch, batch_offset = 1, 0
# Train until the learning rate gets too small
val_loss = None
max_epoch = args.max_epoch or math.inf
lr = trainer.get_lr()
train_meter = StopwatchMeter()
@@ -96,7 +98,8 @@ def main():
# evaluate on validate set
for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset)
val_loss = validate(args, epoch, trainer, dataset,
max_positions_valid, subset)
if k == 0:
# only use first validation loss to update the learning schedule
lr = trainer.lr_step(epoch, val_loss)
@@ -108,10 +111,18 @@ def main():
epoch += 1
batch_offset = 0
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
# Stop multiprocessing
trainer.stop()
def parse_train_args():
parser = options.get_parser('Trainer')
options.add_dataset_args(parser, train=True)
options.add_optimization_args(parser)
options.add_checkpoint_args(parser)
options.add_model_args(parser)
options.add_distributed_training_args(parser)
args = options.parse_args_and_arch(parser)
return args
def get_perplexity(loss):
@@ -129,11 +140,13 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
trainer.set_seed(seed)
itr = dataset.train_dataloader(
args.train_subset, num_workers=args.workers,
args.train_subset,
max_tokens=args.max_tokens, max_sentences=args.max_sentences,
max_positions=max_positions, seed=seed, epoch=epoch,
sample_without_replacement=args.sample_without_replacement,
sort_by_source_size=(epoch <= args.curriculum))
sort_by_source_size=(epoch <= args.curriculum),
shard_id=args.distributed_rank, num_shards=args.distributed_world_size,
)
loss_meter = AverageMeter()
nll_loss_meter = AverageMeter()
bsz_meter = AverageMeter() # sentences per batch
@@ -146,22 +159,23 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
num_updates = trainer.get_num_updates()
with progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple') as t:
for i, sample in enumerate(
data.skip_group_enumerator(t, args.num_gpus, batch_offset),
data.skip_group_enumerator(t, batch_offset),
start=num_updates,
):
loss_dict = trainer.train_step(sample)
loss = loss_dict['loss']
lr = loss_dict['lr']
ntokens = loss_dict['ntokens']
nsentences = loss_dict['nsentences']
del loss_dict['loss'] # don't include in extra_meters or extra_postfix
del loss_dict['lr']
ntokens = sum(s['ntokens'] for s in sample)
del loss_dict['ntokens']
del loss_dict['nsentences']
if 'nll_loss' in loss_dict:
nll_loss = loss_dict['nll_loss']
nll_loss_meter.update(nll_loss, ntokens)
nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample)
loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
bsz_meter.update(nsentences)
wpb_meter.update(ntokens)
@@ -187,7 +201,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
# ignore the first mini-batch in words-per-second calculation
wps_meter.reset()
if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
save_checkpoint(trainer, args, epoch, i + 1)
save_checkpoint(trainer, args, epoch, i + 1, 0)
t.print(collections.OrderedDict([
('train loss', round(loss_meter.avg, 2)),
@@ -240,6 +254,7 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
descending=True, # largest batch first to warm the caching allocator
shard_id=args.distributed_rank, num_shards=args.distributed_world_size,
)
loss_meter = AverageMeter()
nll_loss_meter = AverageMeter()
@@ -247,11 +262,13 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
prefix = 'valid on \'{}\' subset'.format(subset)
with progress_bar.build_progress_bar(args, itr, epoch, prefix, no_progress_bar='simple') as t:
for sample in data.skip_group_enumerator(t, args.num_gpus):
for sample in data.skip_group_enumerator(t):
loss_dict = trainer.valid_step(sample)
ntokens = sum(s['ntokens'] for s in sample)
ntokens = loss_dict['ntokens']
loss = loss_dict['loss']
del loss_dict['loss'] # don't include in extra_meters or extra_postfix
del loss_dict['ntokens']
del loss_dict['nsentences']
if 'nll_loss' in loss_dict:
nll_loss = loss_dict['nll_loss']