Parameterized criterions (#808)

Summary:
Support criterions with parameters, such as the AutoSegmentationCriterion (ASG) used in wav2letter, which has a transition matrix parameter. This is needed to integrate wav2letter's ASG into PySpeech.

With this diff, parameters in criterions will be:
(1) updated by optimizers, with a configurable learning rate
(2) saved and loaded from checkpoints, preserving backward compatibility for criterions without parameters
(3) synchronized across nodes in distributed training (a sketch of such a criterion follows below).
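The sketch below is illustrative only: the class name, registration name, and loss body are hypothetical and are not the actual wav2letter/PySpeech ASG code. The point is that a FairseqCriterion is an nn.Module, so a transition matrix registered as an nn.Parameter is returned by criterion.parameters() and is therefore covered by (1)-(3) above.

import torch
import torch.nn as nn

from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('sketch_asg')
class SketchASGCriterion(FairseqCriterion):
    """Hypothetical ASG-style criterion with a learnable transition matrix."""

    def __init__(self, args, task):
        super().__init__(args, task)
        num_labels = len(task.target_dictionary)
        # Learnable label-to-label transition scores, analogous to ASG's
        # transition matrix; with this diff it is trained, checkpointed,
        # and synchronized like any model parameter.
        self.transitions = nn.Parameter(torch.zeros(num_labels, num_labels))

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample['net_input'])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        # A real implementation would score label sequences using lprobs
        # together with self.transitions; the ASG recursion is omitted here.
        raise NotImplementedError('ASG loss computation omitted in this sketch')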
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/808

Reviewed By: jcai1

Differential Revision: D16934097

Pulled By: okhonko

fbshipit-source-id: 121ec9382459385c6f9cbef3a8274bec1a434038
Jeff Cai authored on 2019-08-21 13:41:41 -07:00; committed by Facebook Github Bot
parent a2f5361d70
commit ba5f829f64
15 changed files with 79 additions and 39 deletions

View File

@@ -222,6 +222,7 @@ def save_state(
     filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
     num_updates, optim_history=None, extra_state=None,
 ):
+    from fairseq import utils
     if optim_history is None:
         optim_history = []
     if extra_state is None:
@@ -239,6 +240,8 @@ def save_state(
         ],
         'extra_state': extra_state,
     }
+    if utils.has_parameters(criterion):
+        state_dict['criterion'] = criterion.state_dict()
     if not args.no_save_optimizer_state:
         state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())
     torch_persistent_save(state_dict, filename)

View File

@@ -5,7 +5,7 @@
 import inspect
 
-from torch.nn import parallel
+import torch.nn as nn
 
 from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
 from fairseq.models import BaseFairseqModel
@@ -25,9 +25,9 @@ def DistributedFairseqModel(args, model):
         model (BaseFairseqModel): model to wrap
     """
     # determine which DDP class to extend
-    assert isinstance(model, BaseFairseqModel)
+    assert isinstance(model, nn.Module)
     if args.ddp_backend == 'c10d':
-        ddp_class = parallel.DistributedDataParallel
+        ddp_class = nn.parallel.DistributedDataParallel
         init_kwargs = dict(
             module=model,
             device_ids=[args.device_id],

View File

@@ -19,18 +19,13 @@ __all__ = [
 ]
 
-_build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
+build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
     '--optimizer',
     base_class=FairseqOptimizer,
     default='nag',
 )
 
-def build_optimizer(args, params, *extra_args, **extra_kwargs):
-    params = list(filter(lambda p: p.requires_grad, params))
-    return _build_optimizer(args, params, *extra_args, **extra_kwargs)
-
 
 # automatically import any Python files in the optim/ directory
 for file in os.listdir(os.path.dirname(__file__)):
     if file.endswith('.py') and not file.startswith('_'):

View File

@@ -11,7 +11,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adadelta')
 class Adadelta(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)
 
     @staticmethod

View File

@@ -13,7 +13,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adafactor')
 class FairseqAdafactor(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = Adafactor(params, **self.optimizer_config)
 
     @staticmethod

View File

@@ -11,7 +11,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adagrad')
 class Adagrad(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
 
     @staticmethod

View File

@@ -16,7 +16,7 @@ from . import FairseqOptimizer, register_optimizer
 class FairseqAdam(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         if torch.cuda.is_available():
             try:
                 from apex.optimizers import FusedAdam as _FusedAdam  # noqa

View File

@@ -12,7 +12,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adamax')
 class FairseqAdamax(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = Adamax(params, **self.optimizer_config)
 
     @staticmethod

View File

@@ -19,11 +19,10 @@ class FairseqBMUF(FairseqOptimizer):
     model-update filtering
     """
 
-    def __init__(self, args, params, optimizer):
-        super().__init__(args, params)
+    def __init__(self, args, optimizer):
+        super().__init__(args)
         self._optimizer = optimizer
-        self.params = params
         self._num_updates = 0
         self.sync_iter = self.args.global_sync_iter
         self.block_momentum = self.args.block_momentum

View File

@@ -10,10 +10,9 @@ import torch
 class FairseqOptimizer(object):
 
-    def __init__(self, args, params):
+    def __init__(self, args):
         super().__init__()
         self.args = args
-        self.params = list(params)
 
     @staticmethod
     def add_args(parser):
@@ -39,6 +38,13 @@ class FairseqOptimizer(object):
         """
         raise NotImplementedError
 
+    @property
+    def params(self):
+        """Return an iterable of the parameters held by the optimizer."""
+        for param_group in self.optimizer.param_groups:
+            for p in param_group['params']:
+                yield p
+
     def __getstate__(self):
         return self._optimizer.__getstate__()
@@ -93,9 +99,8 @@
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
-        for group in self.optimizer.param_groups:
-            for p in group['params']:
-                p.grad = None
+        for p in self.params:
+            p.grad = None
         self.optimizer.zero_grad()
 
     @property
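For context, the new params property above recovers parameters from the wrapped optimizer's param_groups instead of keeping a separate list. Below is a hedged, standalone illustration of that lookup pattern using only a stock torch optimizer (not fairseq code); the variable names are made up for the example.

import torch

# One parameter registered with a stock optimizer; criterion parameters passed
# to build_optimizer end up in param_groups the same way as model parameters.
weight = torch.nn.Parameter(torch.zeros(4, 4))
opt = torch.optim.SGD([weight], lr=0.1)

# Equivalent of walking FairseqOptimizer.params: iterate param_groups, so
# utilities like zero_grad or gradient clipping see every tracked parameter.
params = [p for group in opt.param_groups for p in group['params']]
assert len(params) == 1 and params[0] is weight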

View File

@@ -60,7 +60,8 @@ class FP16Optimizer(optim.FairseqOptimizer):
     """
 
     def __init__(self, args, params, fp32_optimizer, fp32_params):
-        super().__init__(args, params)
+        super().__init__(args)
+        self.fp16_params = params
         self.fp32_optimizer = fp32_optimizer
         self.fp32_params = fp32_params
@@ -149,7 +150,7 @@
         if self._needs_sync:
             # copy FP16 grads to FP32
             offset = 0
-            for p in self.params:
+            for p in self.fp16_params:
                 if not p.requires_grad:
                     continue
                 grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
@@ -196,7 +197,7 @@
         # copy FP32 params back into FP16 model
         offset = 0
-        for p in self.params:
+        for p in self.fp16_params:
             if not p.requires_grad:
                 continue
             numel = p.data.numel()
@@ -205,7 +206,7 @@
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
-        for p in self.params:
+        for p in self.fp16_params:
             p.grad = None
         self._needs_sync = False
@@ -232,7 +233,7 @@ class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
                 'Unsupported optimizer: {}'.format(optimizer.__class__.__name__)
             )
 
-        super().__init__(args, params)
+        super().__init__(args)
         self.wrapped_optimizer = optimizer
 
         if getattr(args, 'fp16_scale_window', None) is None:

View File

@@ -12,7 +12,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('nag')
 class FairseqNAG(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = NAG(params, **self.optimizer_config)
 
     @staticmethod

View File

@@ -11,7 +11,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('sgd')
 class SGD(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
 
     @staticmethod

View File

@@ -36,13 +36,14 @@ class Trainer(object):
         self.task = task
 
         # copy model and criterion to current device
-        self.criterion = criterion
+        self._criterion = criterion
         self._model = model
         self.cuda = torch.cuda.is_available() and not args.cpu
         if args.fp16:
+            self._criterion = self._criterion.half()
             self._model = self._model.half()
         if self.cuda:
-            self.criterion = self.criterion.cuda()
+            self._criterion = self._criterion.cuda()
             self._model = self._model.cuda()
 
         self._dummy_batch = dummy_batch
@@ -53,6 +54,7 @@
         self._optim_history = None
         self._optimizer = None
         self._prev_grad_norm = None
+        self._wrapped_criterion = None
         self._wrapped_model = None
 
         self.init_meters(args)
@@ -75,6 +77,21 @@
         self.meters['wall'] = TimeMeter()  # wall time in seconds
         self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
 
+    @property
+    def criterion(self):
+        if self._wrapped_criterion is None:
+            if (
+                utils.has_parameters(self._criterion)
+                and self.args.distributed_world_size > 1
+                and not self.args.use_bmuf
+            ):
+                self._wrapped_criterion = models.DistributedFairseqModel(
+                    self.args, self._criterion
+                )
+            else:
+                self._wrapped_criterion = self._criterion
+        return self._wrapped_criterion
+
     @property
     def model(self):
         if self._wrapped_model is None:
@@ -99,7 +116,13 @@
         return self._lr_scheduler
 
     def _build_optimizer(self):
-        params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
+        params = list(
+            filter(
+                lambda p: p.requires_grad,
+                chain(self.model.parameters(), self.criterion.parameters()),
+            )
+        )
         if self.args.fp16:
             if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                 print('| WARNING: your device does NOT support faster training with --fp16, '
@@ -114,7 +137,7 @@
             self._optimizer = optim.build_optimizer(self.args, params)
 
         if self.args.use_bmuf:
-            self._optimizer = optim.FairseqBMUF(self.args, params, self._optimizer)
+            self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)
 
         # We should initialize the learning rate scheduler immediately after
         # building the optimizer, so that the initial learning rate is set.
@@ -126,7 +149,7 @@
         if distributed_utils.is_master(self.args):  # only save one checkpoint
             extra_state['train_meters'] = self.meters
             checkpoint_utils.save_state(
-                filename, self.args, self.get_model().state_dict(), self.criterion,
+                filename, self.args, self.get_model().state_dict(), self.get_criterion(),
                 self.optimizer, self.lr_scheduler, self.get_num_updates(),
                 self._optim_history, extra_state,
             )
@@ -148,6 +171,8 @@
             # load model parameters
             try:
                 self.get_model().load_state_dict(state['model'], strict=True)
+                if utils.has_parameters(self.get_criterion()):
+                    self.get_criterion().load_state_dict(state['criterion'], strict=True)
             except Exception:
                 raise Exception(
                     'Cannot load model parameters from checkpoint {}; '
@@ -164,7 +189,7 @@
             # only reload optimizer and lr_scheduler if they match
             last_optim = self._optim_history[-1]
-            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
+            assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
                 'Criterion does not match; please reset the optimizer (--reset-optimizer).'
             assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                 'Optimizer does not match; please reset the optimizer (--reset-optimizer).'
@@ -322,9 +347,9 @@
         # aggregate logging outputs and sample sizes
         logging_output = self.task.aggregate_logging_outputs(
-            logging_outputs, self.criterion
+            logging_outputs, self.get_criterion()
         )
-        sample_size = self.task.grad_denom(sample_sizes, self.criterion)
+        sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
 
         if not all(k in logging_output for k in ['ntokens', 'nsentences']):
             raise Exception((
@@ -424,10 +449,10 @@
         # aggregate logging outputs and sample sizes
         logging_output = self.task.aggregate_logging_outputs(
-            logging_output, self.criterion
+            logging_output, self.get_criterion()
         )
         sample_size = self.task.grad_denom(
-            sample_size, self.criterion
+            sample_size, self.get_criterion()
         )
 
         # update meters for validation
@@ -477,6 +502,10 @@
         """Get the (non-wrapped) model instance."""
         return self._model
 
+    def get_criterion(self):
+        """Get the (non-wrapped) criterion instance."""
+        return self._criterion
+
     def get_meter(self, name):
         """Get a specific meter by name."""
         if name not in self.meters:

View File

@@ -351,3 +351,11 @@ def eval(model):
     model.eval()
     yield
     model.train(is_training)
+
+
+def has_parameters(module):
+    try:
+        next(module.parameters())
+        return True
+    except StopIteration:
+        return False
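A hedged usage sketch for the new helper (assuming a fairseq checkout with this diff applied is importable; the stub classes below are illustrative only). has_parameters() is what gates the new checkpoint 'criterion' entry and the DistributedFairseqModel wrapping of criterions.

import torch
import torch.nn as nn

from fairseq import utils


class NoParamCriterionStub(nn.Module):
    """Stands in for a criterion without parameters (e.g. cross entropy)."""
    pass


class ParamCriterionStub(nn.Module):
    """Stands in for a parameterized criterion such as ASG."""
    def __init__(self):
        super().__init__()
        self.transitions = nn.Parameter(torch.zeros(3, 3))


assert not utils.has_parameters(NoParamCriterionStub())
assert utils.has_parameters(ParamCriterionStub())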