diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index c81213678..4a8855d54 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -222,6 +222,7 @@ def save_state(
         filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
         num_updates, optim_history=None, extra_state=None,
 ):
+    from fairseq import utils
     if optim_history is None:
         optim_history = []
     if extra_state is None:
@@ -239,6 +240,8 @@ def save_state(
         ],
         'extra_state': extra_state,
     }
+    if utils.has_parameters(criterion):
+        state_dict['criterion'] = criterion.state_dict()
     if not args.no_save_optimizer_state:
         state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())
     torch_persistent_save(state_dict, filename)
diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py
index e858717d9..dbf384367 100644
--- a/fairseq/models/distributed_fairseq_model.py
+++ b/fairseq/models/distributed_fairseq_model.py
@@ -5,7 +5,7 @@
 
 import inspect
 
-from torch.nn import parallel
+import torch.nn as nn
 
 from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
 from fairseq.models import BaseFairseqModel
@@ -25,9 +25,9 @@ def DistributedFairseqModel(args, model):
         model (BaseFairseqModel): model to wrap
     """
     # determine which DDP class to extend
-    assert isinstance(model, BaseFairseqModel)
+    assert isinstance(model, nn.Module)
     if args.ddp_backend == 'c10d':
-        ddp_class = parallel.DistributedDataParallel
+        ddp_class = nn.parallel.DistributedDataParallel
         init_kwargs = dict(
             module=model,
             device_ids=[args.device_id],
diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py
index 268291be7..2b8334d8c 100644
--- a/fairseq/optim/__init__.py
+++ b/fairseq/optim/__init__.py
@@ -19,18 +19,13 @@ __all__ = [
 ]
 
 
-_build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
+build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
     '--optimizer',
     base_class=FairseqOptimizer,
     default='nag',
 )
 
 
-def build_optimizer(args, params, *extra_args, **extra_kwargs):
-    params = list(filter(lambda p: p.requires_grad, params))
-    return _build_optimizer(args, params, *extra_args, **extra_kwargs)
-
-
 # automatically import any Python files in the optim/ directory
 for file in os.listdir(os.path.dirname(__file__)):
     if file.endswith('.py') and not file.startswith('_'):
diff --git a/fairseq/optim/adadelta.py b/fairseq/optim/adadelta.py
index 27079a402..8a9d54cc8 100644
--- a/fairseq/optim/adadelta.py
+++ b/fairseq/optim/adadelta.py
@@ -11,7 +11,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adadelta')
 class Adadelta(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)
 
     @staticmethod
diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py
index 1c026244b..680ac371b 100644
--- a/fairseq/optim/adafactor.py
+++ b/fairseq/optim/adafactor.py
@@ -13,7 +13,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adafactor')
 class FairseqAdafactor(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = Adafactor(params, **self.optimizer_config)
 
     @staticmethod
diff --git a/fairseq/optim/adagrad.py b/fairseq/optim/adagrad.py
index 15b3a1c25..5dead3b25 100644
--- a/fairseq/optim/adagrad.py
+++ b/fairseq/optim/adagrad.py
@@ -11,7 +11,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adagrad')
 class Adagrad(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
 
     @staticmethod
diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py
index 0df118206..51a282380 100644
--- a/fairseq/optim/adam.py
+++ b/fairseq/optim/adam.py
@@ -16,7 +16,7 @@ from . import FairseqOptimizer, register_optimizer
 class FairseqAdam(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         if torch.cuda.is_available():
             try:
                 from apex.optimizers import FusedAdam as _FusedAdam  # noqa
diff --git a/fairseq/optim/adamax.py b/fairseq/optim/adamax.py
index 2a2e7698a..a22f7cda8 100644
--- a/fairseq/optim/adamax.py
+++ b/fairseq/optim/adamax.py
@@ -12,7 +12,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adamax')
 class FairseqAdamax(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = Adamax(params, **self.optimizer_config)
 
     @staticmethod
diff --git a/fairseq/optim/bmuf.py b/fairseq/optim/bmuf.py
index 756374d56..651fe7e60 100644
--- a/fairseq/optim/bmuf.py
+++ b/fairseq/optim/bmuf.py
@@ -19,11 +19,10 @@ class FairseqBMUF(FairseqOptimizer):
     model-update filtering
     """
 
-    def __init__(self, args, params, optimizer):
+    def __init__(self, args, optimizer):
 
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = optimizer
-        self.params = params
         self._num_updates = 0
         self.sync_iter = self.args.global_sync_iter
         self.block_momentum = self.args.block_momentum
diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py
index 58bc7fc2d..030b1fe4a 100644
--- a/fairseq/optim/fairseq_optimizer.py
+++ b/fairseq/optim/fairseq_optimizer.py
@@ -10,10 +10,9 @@ import torch
 
 class FairseqOptimizer(object):
 
-    def __init__(self, args, params):
+    def __init__(self, args):
         super().__init__()
         self.args = args
-        self.params = list(params)
 
     @staticmethod
     def add_args(parser):
@@ -39,6 +38,13 @@ class FairseqOptimizer(object):
         """
         raise NotImplementedError
 
+    @property
+    def params(self):
+        """Return an iterable of the parameters held by the optimizer."""
+        for param_group in self.optimizer.param_groups:
+            for p in param_group['params']:
+                yield p
+
     def __getstate__(self):
         return self._optimizer.__getstate__()
 
@@ -93,9 +99,8 @@ class FairseqOptimizer(object):
 
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
-        for group in self.optimizer.param_groups:
-            for p in group['params']:
-                p.grad = None
+        for p in self.params:
+            p.grad = None
         self.optimizer.zero_grad()
 
     @property
diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py
index b3ae1ef49..194e0f4f4 100644
--- a/fairseq/optim/fp16_optimizer.py
+++ b/fairseq/optim/fp16_optimizer.py
@@ -60,7 +60,8 @@ class FP16Optimizer(optim.FairseqOptimizer):
     """
 
     def __init__(self, args, params, fp32_optimizer, fp32_params):
-        super().__init__(args, params)
+        super().__init__(args)
+        self.fp16_params = params
         self.fp32_optimizer = fp32_optimizer
         self.fp32_params = fp32_params
 
@@ -149,7 +150,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
         if self._needs_sync:
             # copy FP16 grads to FP32
             offset = 0
-            for p in self.params:
+            for p in self.fp16_params:
                 if not p.requires_grad:
                     continue
                 grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
@@ -196,7 +197,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
 
         # copy FP32 params back into FP16 model
         offset = 0
-        for p in self.params:
+        for p in self.fp16_params:
             if not p.requires_grad:
                 continue
             numel = p.data.numel()
@@ -205,7 +206,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
 
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
-        for p in self.params:
+        for p in self.fp16_params:
             p.grad = None
         self._needs_sync = False
 
@@ -232,7 +233,7 @@ class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
                 'Unsupported optimizer: {}'.format(optimizer.__class__.__name__)
             )
 
-        super().__init__(args, params)
+        super().__init__(args)
         self.wrapped_optimizer = optimizer
 
         if getattr(args, 'fp16_scale_window', None) is None:
diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py
index c916b6fad..25c260987 100644
--- a/fairseq/optim/nag.py
+++ b/fairseq/optim/nag.py
@@ -12,7 +12,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('nag')
 class FairseqNAG(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = NAG(params, **self.optimizer_config)
 
     @staticmethod
diff --git a/fairseq/optim/sgd.py b/fairseq/optim/sgd.py
index c34b9590d..0efb283c6 100644
--- a/fairseq/optim/sgd.py
+++ b/fairseq/optim/sgd.py
@@ -11,7 +11,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('sgd')
 class SGD(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
 
     @staticmethod
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index ce0e74dc9..58448c83a 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -36,13 +36,14 @@ class Trainer(object):
         self.task = task
 
         # copy model and criterion to current device
-        self.criterion = criterion
+        self._criterion = criterion
         self._model = model
         self.cuda = torch.cuda.is_available() and not args.cpu
         if args.fp16:
+            self._criterion = self._criterion.half()
             self._model = self._model.half()
         if self.cuda:
-            self.criterion = self.criterion.cuda()
+            self._criterion = self._criterion.cuda()
             self._model = self._model.cuda()
 
         self._dummy_batch = dummy_batch
@@ -53,6 +54,7 @@ class Trainer(object):
         self._optim_history = None
         self._optimizer = None
         self._prev_grad_norm = None
+        self._wrapped_criterion = None
         self._wrapped_model = None
 
         self.init_meters(args)
@@ -75,6 +77,21 @@ class Trainer(object):
         self.meters['wall'] = TimeMeter()  # wall time in seconds
         self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
 
+    @property
+    def criterion(self):
+        if self._wrapped_criterion is None:
+            if (
+                utils.has_parameters(self._criterion)
+                and self.args.distributed_world_size > 1
+                and not self.args.use_bmuf
+            ):
+                self._wrapped_criterion = models.DistributedFairseqModel(
+                    self.args, self._criterion
+                )
+            else:
+                self._wrapped_criterion = self._criterion
+        return self._wrapped_criterion
+
     @property
     def model(self):
         if self._wrapped_model is None:
@@ -99,7 +116,13 @@ class Trainer(object):
         return self._lr_scheduler
 
     def _build_optimizer(self):
-        params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
+        params = list(
+            filter(
+                lambda p: p.requires_grad,
+                chain(self.model.parameters(), self.criterion.parameters()),
+            )
+        )
+
         if self.args.fp16:
             if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                 print('| WARNING: your device does NOT support faster training with --fp16, '
@@ -114,7 +137,7 @@ class Trainer(object):
             self._optimizer = optim.build_optimizer(self.args, params)
 
         if self.args.use_bmuf:
-            self._optimizer = optim.FairseqBMUF(self.args, params, self._optimizer)
+            self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)
 
         # We should initialize the learning rate scheduler immediately after
         # building the optimizer, so that the initial learning rate is set.
@@ -126,7 +149,7 @@ class Trainer(object):
         if distributed_utils.is_master(self.args):  # only save one checkpoint
             extra_state['train_meters'] = self.meters
             checkpoint_utils.save_state(
-                filename, self.args, self.get_model().state_dict(), self.criterion,
+                filename, self.args, self.get_model().state_dict(), self.get_criterion(),
                 self.optimizer, self.lr_scheduler, self.get_num_updates(),
                 self._optim_history, extra_state,
             )
@@ -148,6 +171,8 @@ class Trainer(object):
             # load model parameters
             try:
                 self.get_model().load_state_dict(state['model'], strict=True)
+                if utils.has_parameters(self.get_criterion()):
+                    self.get_criterion().load_state_dict(state['criterion'], strict=True)
             except Exception:
                 raise Exception(
                     'Cannot load model parameters from checkpoint {}; '
@@ -164,7 +189,7 @@ class Trainer(object):
 
             # only reload optimizer and lr_scheduler if they match
             last_optim = self._optim_history[-1]
-            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
+            assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
                 'Criterion does not match; please reset the optimizer (--reset-optimizer).'
             assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                 'Optimizer does not match; please reset the optimizer (--reset-optimizer).'
@@ -322,9 +347,9 @@ class Trainer(object):
 
         # aggregate logging outputs and sample sizes
         logging_output = self.task.aggregate_logging_outputs(
-            logging_outputs, self.criterion
+            logging_outputs, self.get_criterion()
        )
-        sample_size = self.task.grad_denom(sample_sizes, self.criterion)
+        sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
 
         if not all(k in logging_output for k in ['ntokens', 'nsentences']):
             raise Exception((
@@ -424,10 +449,10 @@ class Trainer(object):
 
         # aggregate logging outputs and sample sizes
         logging_output = self.task.aggregate_logging_outputs(
-            logging_output, self.criterion
+            logging_output, self.get_criterion()
         )
         sample_size = self.task.grad_denom(
-            sample_size, self.criterion
+            sample_size, self.get_criterion()
         )
 
         # update meters for validation
@@ -477,6 +502,10 @@ class Trainer(object):
         """Get the (non-wrapped) model instance."""
         return self._model
 
+    def get_criterion(self):
+        """Get the (non-wrapped) criterion instance."""
+        return self._criterion
+
     def get_meter(self, name):
         """Get a specific meter by name."""
         if name not in self.meters:
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 1b664cbfe..1af239443 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -351,3 +351,11 @@ def eval(model):
     model.eval()
     yield
     model.train(is_training)
+
+
+def has_parameters(module):
+    try:
+        next(module.parameters())
+        return True
+    except StopIteration:
+        return False
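
A minimal standalone sketch of the two building blocks this change relies on: a has_parameters check (same idea as the new fairseq.utils.has_parameters) and a params generator that walks optimizer.param_groups instead of storing a separate parameter list. The ScaledLoss criterion below is hypothetical and only illustrates why a criterion with its own nn.Parameter now has to be included when the optimizer is built and when checkpoints are saved; it is not part of the diff above.

import torch
import torch.nn as nn
from itertools import chain


def has_parameters(module):
    # True if the module owns at least one parameter
    # (mirrors the new fairseq.utils.has_parameters).
    try:
        next(module.parameters())
        return True
    except StopIteration:
        return False


class ScaledLoss(nn.Module):
    """Hypothetical criterion with a learnable scale, so it has trainable state."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, output, target):
        return self.scale * nn.functional.mse_loss(output, target)


model = nn.Linear(4, 1)
criterion = ScaledLoss()

# Build the optimizer over model *and* criterion parameters, analogous to the
# updated Trainer._build_optimizer, which chains the two parameter iterators.
params = [p for p in chain(model.parameters(), criterion.parameters()) if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.1)


def optimizer_params(optimizer):
    # Same pattern as the new FairseqOptimizer.params property: iterate the
    # optimizer's own param_groups rather than keeping a separate list.
    for group in optimizer.param_groups:
        for p in group['params']:
            yield p


print(has_parameters(criterion))   # True  -> its state_dict would be checkpointed
print(has_parameters(nn.ReLU()))   # False -> nothing extra to save
print(sum(p.numel() for p in optimizer_params(optimizer)))  # 4 + 1 + 1 = 6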