mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-09-22 06:39:29 +03:00
Add optimizer history to checkpoints (and rearrange criterions slightly)
This commit is contained in:
parent
48631f7a3c
commit
e432459b37
@ -18,14 +18,14 @@ class CrossEntropyCriterion(FairseqCriterion):
|
||||
super().__init__()
|
||||
self.padding_idx = padding_idx
|
||||
|
||||
def prepare(self, samples):
|
||||
self.denom = sum(s['ntokens'] if s else 0 for s in samples)
|
||||
def grad_denom(self, samples):
    """Gradient normalization term: total 'ntokens' over the non-empty samples.

    Falsy entries in ``samples`` (e.g. ``None``) contribute nothing.
    """
    return sum(sample['ntokens'] for sample in samples if sample)
|
||||
|
||||
def forward(self, net_output, sample):
|
||||
input = net_output.view(-1, net_output.size(-1))
|
||||
target = sample['target'].view(-1)
|
||||
loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
|
||||
return loss / self.denom
|
||||
return loss
|
||||
|
||||
def aggregate(self, losses):
|
||||
return sum(losses) / math.log(2)
|
||||
|
@ -11,13 +11,17 @@ from torch.nn.modules.loss import _Loss
|
||||
|
||||
class FairseqCriterion(_Loss):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def prepare(self, samples):
|
||||
"""Prepare criterion for DataParallel training."""
|
||||
def grad_denom(self, samples):
|
||||
"""Gradient normalization term for DataParallel training."""
|
||||
raise NotImplementedError
|
||||
|
||||
def prepare(self, model, sample):
|
||||
"""Apply criterion-specific modifications to the sample."""
|
||||
return sample
|
||||
|
||||
def forward(self, net_output, sample):
|
||||
"""Compute the loss for the given sample and network output."""
|
||||
raise NotImplementedError
|
||||
|
@ -49,14 +49,14 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
|
||||
self.padding_idx = padding_idx
|
||||
self.weights = weights
|
||||
|
||||
def prepare(self, samples):
|
||||
self.denom = sum(s['ntokens'] if s else 0 for s in samples)
|
||||
def grad_denom(self, samples):
    """Gradient normalization term: total 'ntokens' over the non-empty samples.

    Falsy entries in ``samples`` (e.g. ``None``) contribute nothing.
    """
    return sum(sample['ntokens'] for sample in samples if sample)
|
||||
|
||||
def forward(self, net_output, sample):
|
||||
input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
|
||||
target = sample['target'].view(-1)
|
||||
loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
|
||||
return loss / self.denom
|
||||
return loss
|
||||
|
||||
def aggregate(self, losses):
|
||||
return sum(losses) / math.log(2)
|
||||
|
@ -32,7 +32,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
|
||||
(prefixed with `_async_`), which run on each process in parallel.
|
||||
"""
|
||||
|
||||
def __init__(self, args, model, device_ids=None,
|
||||
def __init__(self, args, model, criterion, device_ids=None,
|
||||
multiprocessing_method='spawn'):
|
||||
if device_ids is None:
|
||||
device_ids = tuple(range(torch.cuda.device_count()))
|
||||
@ -42,16 +42,17 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
|
||||
raise NotImplementedError('Training on CPU is not supported')
|
||||
model = model.share_memory()
|
||||
nccl_uid = nccl.get_unique_id()
|
||||
self.criterion = criterion
|
||||
|
||||
Future.gen_list([
|
||||
self.call_async(rank, '_async_init', args=args, model=model,
|
||||
nccl_uid=nccl_uid)
|
||||
criterion=criterion, nccl_uid=nccl_uid)
|
||||
for rank in range(self.num_replicas)
|
||||
])
|
||||
|
||||
self._grads_initialized = False
|
||||
|
||||
def _async_init(self, rank, device_id, args, model, nccl_uid):
|
||||
def _async_init(self, rank, device_id, args, model, criterion, nccl_uid):
|
||||
"""Initialize child processes."""
|
||||
self.args = args
|
||||
|
||||
@ -64,8 +65,9 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
|
||||
# initialize NCCL
|
||||
nccl.initialize(self.num_replicas, nccl_uid, device_id)
|
||||
|
||||
# copy model to current device
|
||||
# copy model and criterion to current device
|
||||
self.model = model.cuda()
|
||||
self.criterion = criterion.cuda()
|
||||
|
||||
# initialize optimizer
|
||||
self.optimizer = NAG(self.model.parameters(), lr=self.args.lr,
|
||||
@ -104,8 +106,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
|
||||
batch_offset=batch_offset, val_loss=val_loss).gen()
|
||||
|
||||
def _async_save_checkpoint(self, rank, device_id, args, epoch, batch_offset, val_loss):
|
||||
utils.save_checkpoint(args, epoch, batch_offset, self.model,
|
||||
self.optimizer, self.lr_scheduler, val_loss)
|
||||
utils.save_checkpoint(args, epoch, batch_offset, self.model, self.criterion,
|
||||
self.optimizer, self.lr_scheduler, val_loss, self._optim_history)
|
||||
|
||||
def load_checkpoint(self, filename):
|
||||
"""Load a checkpoint into the model replicas in each process."""
|
||||
@ -117,13 +119,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
|
||||
return epoch, batch_offset
|
||||
|
||||
def _async_load_checkpoint(self, rank, device_id, filename):
|
||||
return utils.load_checkpoint(filename, self.model, self.optimizer,
|
||||
self.lr_scheduler, cuda_device=device_id)
|
||||
epoch, batch_offset, self._optim_history = utils.load_checkpoint(
|
||||
filename, self.model, self.criterion, self.optimizer, self.lr_scheduler,
|
||||
cuda_device=device_id)
|
||||
return epoch, batch_offset
|
||||
|
||||
def train_step(self, samples, criterion):
|
||||
def train_step(self, samples):
|
||||
"""Do forward, backward and gradient step in parallel."""
|
||||
assert isinstance(criterion, FairseqCriterion)
|
||||
|
||||
# PyTorch initializes gradient buffers lazily, so the first
|
||||
# train step needs to send non-empty samples to all replicas
|
||||
replace_empty_samples = False
|
||||
@ -133,31 +135,36 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
|
||||
|
||||
# scatter sample across GPUs
|
||||
self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
|
||||
criterion.prepare(samples)
|
||||
|
||||
# calculate gradient normalization term
|
||||
grad_denom = self.criterion.grad_denom(samples)
|
||||
|
||||
# forward pass, backward pass and gradient step
|
||||
losses = [
|
||||
self.call_async(rank, '_async_train_step', criterion=criterion)
|
||||
self.call_async(rank, '_async_train_step', grad_denom=grad_denom)
|
||||
for rank in range(self.num_replicas)
|
||||
]
|
||||
|
||||
# aggregate losses and gradient norms
|
||||
losses, grad_norms = Future.gen_tuple_list(losses)
|
||||
loss = criterion.aggregate(losses)
|
||||
loss = self.criterion.aggregate(losses)
|
||||
|
||||
return loss, grad_norms[0]
|
||||
|
||||
def _async_train_step(self, rank, device_id, criterion):
|
||||
def _async_train_step(self, rank, device_id, grad_denom):
|
||||
self.model.train()
|
||||
|
||||
# zero grads even if net_input is None, since we will all-reduce them
|
||||
# zero grads even if self._sample is None, since we will all-reduce them
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# calculate loss and grads
|
||||
loss = 0
|
||||
if self._sample is not None:
|
||||
self._sample = self.criterion.prepare(self.model, self._sample)
|
||||
net_output = self.model(**self._sample['net_input'])
|
||||
loss_ = criterion(net_output, self._sample)
|
||||
loss_ = self.criterion(net_output, self._sample)
|
||||
if grad_denom is not None:
|
||||
loss_ /= grad_denom
|
||||
loss_.backward()
|
||||
loss = loss_.data[0]
|
||||
|
||||
@ -196,29 +203,34 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
|
||||
flat_grads.div_(coef)
|
||||
return norm
|
||||
|
||||
def valid_step(self, samples, criterion):
|
||||
def valid_step(self, samples):
|
||||
"""Do forward pass in parallel."""
|
||||
# scatter sample across GPUs
|
||||
self._scatter_samples(samples, volatile=True)
|
||||
criterion.prepare(samples)
|
||||
|
||||
# calculate gradient normalization term
|
||||
grad_denom = self.criterion.grad_denom(samples)
|
||||
|
||||
# forward pass
|
||||
losses = [
|
||||
self.call_async(rank, '_async_valid_step', criterion=criterion)
|
||||
self.call_async(rank, '_async_valid_step', grad_denom=grad_denom)
|
||||
for rank in range(self.num_replicas)
|
||||
]
|
||||
|
||||
# aggregate losses
|
||||
loss = criterion.aggregate(Future.gen_list(losses))
|
||||
loss = self.criterion.aggregate(Future.gen_list(losses))
|
||||
|
||||
return loss
|
||||
|
||||
def _async_valid_step(self, rank, device_id, criterion):
|
||||
def _async_valid_step(self, rank, device_id, grad_denom):
|
||||
if self._sample is None:
|
||||
return 0
|
||||
self.model.eval()
|
||||
self._sample = self.criterion.prepare(self.model, self._sample)
|
||||
net_output = self.model(**self._sample['net_input'])
|
||||
loss = criterion(net_output, self._sample)
|
||||
loss = self.criterion(net_output, self._sample)
|
||||
if grad_denom is not None:
|
||||
loss /= grad_denom
|
||||
return loss.data[0]
|
||||
|
||||
def get_lr(self):
|
||||
|
@ -46,15 +46,23 @@ def torch_persistent_save(*args, **kwargs):
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
|
||||
def save_checkpoint(args, epoch, batch_offset, model, optimizer, lr_scheduler, val_loss=None):
|
||||
def save_checkpoint(args, epoch, batch_offset, model, criterion, optimizer, lr_scheduler,
|
||||
val_loss=None, optim_history=None):
|
||||
if optim_history is None:
|
||||
optim_history = []
|
||||
state_dict = {
|
||||
'args': args,
|
||||
'epoch': epoch,
|
||||
'batch_offset': batch_offset,
|
||||
'model': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'best_loss': lr_scheduler.best,
|
||||
'val_loss': val_loss,
|
||||
'optimizer_history': optim_history + [
|
||||
{
|
||||
'criterion_name': criterion.__class__.__name__,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'best_loss': lr_scheduler.best,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
if batch_offset == 0:
|
||||
@ -72,9 +80,9 @@ def save_checkpoint(args, epoch, batch_offset, model, optimizer, lr_scheduler, v
|
||||
torch_persistent_save(state_dict, last_filename)
|
||||
|
||||
|
||||
def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
|
||||
def load_checkpoint(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
|
||||
if not os.path.exists(filename):
|
||||
return 1, 0
|
||||
return 1, 0, []
|
||||
if cuda_device is None:
|
||||
state = torch.load(filename)
|
||||
else:
|
||||
@ -82,16 +90,41 @@ def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
|
||||
filename,
|
||||
map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
|
||||
)
|
||||
state = _upgrade_state_dict(state)
|
||||
|
||||
model.load_state_dict(state['model'])
|
||||
optimizer.load_state_dict(state['optimizer'])
|
||||
lr_scheduler.best = state['best_loss']
|
||||
epoch = state['epoch'] + 1
|
||||
batch_offset = state['batch_offset']
|
||||
|
||||
# only load optimizer and lr_scheduler if they match with the checkpoint
|
||||
opt_str = ''
|
||||
optim_history = state['optimizer_history']
|
||||
last_optim = optim_history[-1]
|
||||
if last_optim['criterion_name'] == criterion.__class__.__name__:
|
||||
optimizer.load_state_dict(last_optim['optimizer'])
|
||||
lr_scheduler.best = last_optim['best_loss']
|
||||
opt_str = '; criterion: {}'.format(last_optim['criterion_name'])
|
||||
|
||||
gpu_str = ' on GPU #{}'.format(cuda_device) if cuda_device is not None else ''
|
||||
print('| loaded checkpoint {} (epoch {}){}'.format(filename, epoch, gpu_str))
|
||||
return epoch, batch_offset
|
||||
print('| loaded checkpoint {} (epoch {}{}){}'.format(filename, epoch, opt_str, gpu_str))
|
||||
|
||||
return epoch, batch_offset, optim_history
|
||||
|
||||
|
||||
def _upgrade_state_dict(state):
|
||||
"""Helper for upgrading old model checkpoints."""
|
||||
# add optimizer_history
|
||||
if 'optimizer_history' not in state:
|
||||
state['optimizer_history'] = [
|
||||
{
|
||||
'criterion_name': criterions.CrossEntropyCriterion.__name__,
|
||||
'optimizer': state['optimizer'],
|
||||
'best_loss': state['best_loss'],
|
||||
},
|
||||
]
|
||||
del state['optimizer']
|
||||
del state['best_loss']
|
||||
return state
|
||||
|
||||
|
||||
def load_ensemble_for_inference(filenames, data_path, split):
|
||||
|
14
train.py
14
train.py
@ -68,7 +68,7 @@ def main():
|
||||
criterion = utils.build_criterion(args, dataset)
|
||||
|
||||
# Start multiprocessing
|
||||
trainer = MultiprocessingTrainer(args, model)
|
||||
trainer = MultiprocessingTrainer(args, model, criterion)
|
||||
|
||||
# Load the latest checkpoint if one is available
|
||||
epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))
|
||||
@ -81,11 +81,11 @@ def main():
|
||||
train_meter.start()
|
||||
while lr > args.min_lr and epoch <= max_epoch:
|
||||
# train for one epoch
|
||||
train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus)
|
||||
train(args, epoch, batch_offset, trainer, dataset, num_gpus)
|
||||
|
||||
# evaluate on validate set
|
||||
for k, subset in enumerate(args.valid_subset.split(',')):
|
||||
val_loss = validate(args, epoch, trainer, criterion, dataset, subset, num_gpus)
|
||||
val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus)
|
||||
if k == 0:
|
||||
if not args.no_save:
|
||||
# save checkpoint
|
||||
@ -102,7 +102,7 @@ def main():
|
||||
trainer.stop()
|
||||
|
||||
|
||||
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
|
||||
def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
|
||||
"""Train the model for one epoch."""
|
||||
|
||||
itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
|
||||
@ -121,7 +121,7 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
|
||||
lr = trainer.get_lr()
|
||||
with progress_bar(itr, desc, leave=False) as t:
|
||||
for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
|
||||
loss, grad_norm = trainer.train_step(sample, criterion)
|
||||
loss, grad_norm = trainer.train_step(sample)
|
||||
|
||||
ntokens = sum(s['ntokens'] for s in sample)
|
||||
src_size = sum(s['src_tokens'].size(0) for s in sample)
|
||||
@ -160,7 +160,7 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
|
||||
gnorm_meter.avg))
|
||||
|
||||
|
||||
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
|
||||
def validate(args, epoch, trainer, dataset, subset, ngpus):
|
||||
"""Evaluate the model on the validation set and return the average loss."""
|
||||
|
||||
itr = dataset.dataloader(subset, batch_size=None,
|
||||
@ -173,7 +173,7 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
|
||||
with progress_bar(itr, desc, leave=False) as t:
|
||||
for _, sample in data.skip_group_enumerator(t, ngpus):
|
||||
ntokens = sum(s['ntokens'] for s in sample)
|
||||
loss = trainer.valid_step(sample, criterion)
|
||||
loss = trainer.valid_step(sample)
|
||||
loss_meter.update(loss, ntokens)
|
||||
t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user