composite optimizer

Summary:
This adds a composite optimizer and a pass-through learning rate scheduler that allow fairseq models to use separate optimizers (each optionally with its own lr scheduler) for different parameters. To use this, set a "param_group" attribute on the parameters you want optimized separately (the remaining params are automatically placed into a "default" group), then specify a composite optimizer with a nested optimizer (and, optionally, an lr scheduler) for each group name (see the example below).

For fp16 training, this requires setting fp16_no_flatten_grads to true.

One possible area for future improvement is to automatically create param groups based on module names, but this is still to be discussed.
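Such a mechanism is not part of this diff; a rough sketch of what it could look like, assuming a user-supplied mapping from parameter-name prefixes to group names (both the helper and the mapping are hypothetical):

```python
def tag_param_groups(model, prefix_to_group):
    """Assign p.param_group based on the longest matching parameter-name prefix.

    `model` is any torch.nn.Module; `prefix_to_group` is a hypothetical
    user-supplied dict such as {"encoder.pos_conv": "pos_conv"}.
    """
    for name, p in model.named_parameters():
        matches = [pfx for pfx in prefix_to_group if name.startswith(pfx)]
        if matches:
            p.param_group = prefix_to_group[max(matches, key=len)]
```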

For example, to tag parameters manually, I can modify the wav2vec2 model and add
```python
for p in self.pos_conv.parameters():
    p.param_group = "pos_conv"
```
in the TransformerEncoder class, just after pos_conv is created.

Then I create the following config:

```yaml
# @package _group_

hydra:
  run:
    dir: .
  job_logging:
    disable_existing_loggers: false

common:
  fp16: true
  log_format: json
  log_interval: 10
  fp16_no_flatten_grads: true

checkpoint:
  save_interval_updates: 20
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  no_save: false
```
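The config then continues with the optimizer section: the composite optimizer takes one nested optimizer (and, optionally, an lr scheduler) per group name, and the top-level lr scheduler is set to pass_through so that scheduling is delegated to the per-group schedulers. A minimal sketch of that part; the optimizer choices, warmup values, and learning rates below are illustrative only, not the values from an actual run:

```yaml
# illustrative sketch -- group names must match the param_group values set in
# the model ("default" for untagged params, "pos_conv" from the example above)
optimizer:
  _name: composite
  groups:
    default:
      lr: [0.0005]
      optimizer:
        _name: adam
        adam_betas: (0.9,0.98)
      lr_scheduler:
        _name: polynomial_decay
        warmup_updates: 10000
    pos_conv:
      lr: [0.00005]
      optimizer:
        _name: adam
        adam_betas: (0.9,0.98)
      lr_scheduler:
        _name: polynomial_decay
        warmup_updates: 10000

lr_scheduler:
  _name: pass_through
```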

Reviewed By: myleott

Differential Revision: D25152032

fbshipit-source-id: c73ff95146ecc2a04660c67bcad02b637c5c5098
Alexei Baevski 2020-12-04 17:34:08 -08:00 committed by Facebook GitHub Bot
parent d5218f8827
commit ba4f54267a
14 changed files with 311 additions and 35 deletions

View File

@@ -8,7 +8,7 @@ import math
from argparse import Namespace
from dataclasses import dataclass, field
from omegaconf import II
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn.functional as F

View File

@@ -3,14 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import logging
from dataclasses import dataclass, field
from typing import Dict, List
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.distributed_utils import get_data_parallel_world_size
logger = logging.getLogger(__name__)
@@ -80,6 +79,7 @@ class ModelCriterion(FairseqCriterion):
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
"_world_size": 1,
}
for lk in self.log_keys:
@@ -113,9 +113,12 @@ class ModelCriterion(FairseqCriterion):
"ntokens",
"nsentences",
"sample_size",
"_world_size",
}
world_size = get_data_parallel_world_size()
world_size = utils.item(
sum(log.get("_world_size", 0) for log in logging_outputs)
)
for k in logging_outputs[0]:
if k not in builtin_keys:

View File

@@ -4,19 +4,19 @@
# LICENSE file in the root directory of this source tree.
import ast
import os
import inspect
import logging
import os
import re
from argparse import ArgumentError, ArgumentParser, Namespace
from dataclasses import _MISSING_TYPE, MISSING
from enum import Enum
import inspect
from typing import Any, Dict, List, Tuple, Type
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import FairseqConfig
from hydra.experimental import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize
from omegaconf import DictConfig, OmegaConf, open_dict
logger = logging.getLogger(__name__)
@@ -218,7 +218,9 @@ def _override_attr(
isinstance(val, str)
and not val.startswith("${") # not interpolation
and field_type != str
and (not inspect.isclass(field_type) or not issubclass(field_type, Enum)) # not choices enum
and (
not inspect.isclass(field_type) or not issubclass(field_type, Enum)
) # not choices enum
):
# upgrade old models that stored complex parameters as string
val = ast.literal_eval(val)
@@ -438,9 +440,7 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]):
def merge_with_parent(dc: FairseqDataclass, cfg: FairseqDataclass):
dc_instance = DictConfig(dc)
dc_instance.__dict__["_parent"] = cfg.__dict__["_parent"]
with open_dict(dc_instance):
cfg = OmegaConf.merge(dc_instance, cfg)
OmegaConf.set_struct(cfg, True)
return cfg
merged_cfg = OmegaConf.merge(dc, cfg)
merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
OmegaConf.set_struct(merged_cfg, True)
return merged_cfg

View File

@@ -161,8 +161,9 @@ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False):
elif cfg.distributed_world_size > 1 or force_distributed:
# fallback for single node with multiple GPUs
assert cfg.distributed_world_size <= torch.cuda.device_count(), \
f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices"
assert (
cfg.distributed_world_size <= torch.cuda.device_count()
), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices"
port = random.randint(10000, 20000)
cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port)
@@ -376,8 +377,10 @@ def get_world_size(group):
assert group[0] == "tpu"
my_group = _find_my_group(group[1])
return len(my_group)
else:
elif torch.distributed.is_initialized():
return dist.get_world_size(group=group)
else:
return 1
def get_global_group():
@@ -416,6 +419,7 @@ def get_data_parallel_group():
global _USE_MEGATRON
if _USE_MEGATRON:
from fairseq.model_parallel.megatron import mpu
return mpu.get_data_parallel_group()
else:
return get_global_group()
@@ -435,6 +439,7 @@ def get_model_parallel_group():
global _USE_MEGATRON
if _USE_MEGATRON:
from fairseq.model_parallel.megatron import mpu
return mpu.get_model_parallel_group()
else:
return None

View File

@@ -63,7 +63,11 @@ def build_model(cfg: FairseqDataclass, task):
cfg = cfg[model_type]
else:
raise Exception(
"Could not infer model type from directory. Please add _name field to indicate model type"
"Could not infer model type from directory. Please add _name field to indicate model type. "
"Available models: "
+ str(MODEL_DATACLASS_REGISTRY.keys())
+ " Requested model type: "
+ model_type
)
if model_type in ARCH_MODEL_REGISTRY:
@@ -81,7 +85,13 @@ def build_model(cfg: FairseqDataclass, task):
else:
cfg = merge_with_parent(dc(), cfg)
assert model is not None, f"Could not infer model type from {cfg}"
assert model is not None, (
f"Could not infer model type from {cfg}. "
f"Available models: "
+ str(MODEL_DATACLASS_REGISTRY.keys())
+ " Requested model type: "
+ model_type
)
return model.build_model(cfg, task)

View File

@@ -8,11 +8,14 @@ from torch import nn
class SamePad(nn.Module):
def __init__(self, kernel_size):
def __init__(self, kernel_size, causal=False):
super().__init__()
self.remove = kernel_size % 2 == 0
if causal:
self.remove = kernel_size - 1
else:
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
if self.remove:
x = x[:, :, :-1]
if self.remove > 0:
x = x[:, :, : -self.remove]
return x

fairseq/optim/composite.py (new file, 183 lines)
View File

@@ -0,0 +1,183 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional
import torch.optim
from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer, _build_optimizer
from fairseq.optim.lr_scheduler import FairseqLRScheduler, build_lr_scheduler
from omegaconf import II, open_dict
logger = logging.getLogger(__name__)
@dataclass
class OptimizerAndSchedulerConfig(FairseqDataclass):
optimizer: Any = None
lr_scheduler: Optional[Any] = None
lr: List[float] = II("optimization.lr")
@dataclass
class CompositeOptimizerConfig(FairseqDataclass):
groups: Dict[str, OptimizerAndSchedulerConfig] = field(
default_factory=lambda: {},
metadata={
"help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. "
"Configures a different optimizer and (optionally) lr scheduler for each parameter group"
},
)
@register_optimizer("composite", dataclass=CompositeOptimizerConfig)
class FairseqCompositeOptimizer(FairseqOptimizer):
optimizers: Dict[str, FairseqOptimizer] = {}
lr_schedulers: Dict[str, FairseqLRScheduler] = {}
lr_scheduler: FairseqLRScheduler = None
_optimizer: torch.optim.Optimizer
def __init__(self, cfg: CompositeOptimizerConfig, params):
super().__init__(cfg)
assert (
len(params) > 1
), "Composite optimizer only works when there are multiple parameter groups (try fp16_no_flatten_grads: true)"
groupped_params = defaultdict(list)
for p in params:
group = getattr(p, "param_group", "default")
groupped_params[group].append(p)
assert groupped_params.keys() == cfg.groups.keys(), (
f"Parameter groups {groupped_params.keys()} and optimizer groups {cfg.groups.keys()} are not the same! "
"Try setting 'param_group' on your parameters in the model."
)
for group, group_params in groupped_params.items():
group_cfg = cfg.groups[group]
with open_dict(group_cfg):
group_cfg.optimizer.lr = group_cfg.lr
group_cfg.lr_scheduler.lr = group_cfg.lr
self.optimizers[group] = _build_optimizer(group_cfg.optimizer, group_params)
if group_cfg.lr_scheduler is not None:
self.lr_schedulers[group] = build_lr_scheduler(
group_cfg.lr_scheduler, self.optimizers[group]
)
if len(self.lr_schedulers) > 0:
assert len(self.lr_schedulers) == len(self.optimizers), (
f"Please provide an lr scheduler for each optimizer to use pass_through scheduler. "
f"Optimizers: {self.optimizers}; Lr scheds: {self.lr_schedulers}"
)
self.lr_scheduler = CompositeLRScheduler(self.lr_schedulers)
self._optimizer = CompositeOptimizer(self.optimizers)
@property
def supports_groups(self):
return True
@property
def param_groups(self):
for opt in self.optimizers.values():
for group in opt.param_groups:
yield group
def get_lr(self):
"""Return the current learning rate."""
k = (
"default"
if "default" in self.optimizers
else next(iter(self.optimizers.keys()))
)
return self.optimizers[k].param_groups[0]["lr"]
def state_dict(self):
"""Return the LR scheduler state dict."""
return {k: s.state_dict() for k, s in self.optimizers.items()}
def load_state_dict(self, state_dict, optimizer_overrides=None):
"""Load an LR scheduler state dict."""
for k, state in state_dict.items():
if k not in self.optimizers:
# skip extra keys like "loss_scale" added by fp16 optimizer
continue
overrides = (
optimizer_overrides[k]
if isinstance(optimizer_overrides, dict) and k in optimizer_overrides
else None
)
self.optimizers[k].load_state_dict(state, optimizer_overrides=overrides)
class CompositeOptimizer(torch.optim.Optimizer):
def __init__(self, optimizers: Dict[str, FairseqOptimizer]):
self.optimizers = optimizers
@property
def supports_memory_efficient_fp16(self):
return all(o.supports_memory_efficient_fp16 for o in self.optimizers.values())
@property
def supports_flat_params(self):
return all(o.supports_flat_params for o in self.optimizers.values())
def step(self, closure=None, groups=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for k, opt in self.optimizers.items():
if groups is None or k in groups:
opt.step()
return loss
def zero_grad(self):
for opt in self.optimizers.values():
opt.zero_grad()
class CompositeLRScheduler(FairseqLRScheduler):
def __init__(self, lr_schedulers):
super().__init__(None, None)
self.lr_schedulers = lr_schedulers
def state_dict(self):
"""Return the LR scheduler state dict."""
return {k: s.state_dict() for k, s in self.lr_schedulers.items()}
def load_state_dict(self, state_dict):
"""Load an LR scheduler state dict."""
for k, state in state_dict.items():
self.lr_schedulers[k].load_state_dict(state)
def step_begin_epoch(self, epoch):
"""Update the learning rate at the beginning of the given epoch."""
for s in self.lr_schedulers.values():
s.step_begin_epoch(epoch)
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
for s in self.lr_schedulers.values():
s.step(epoch)
def step_update(self, num_updates):
"""Update the learning rate after each update."""
return {k: s.step_update(num_updates) for k, s in self.lr_schedulers.items()}

View File

@@ -109,14 +109,21 @@ class FairseqOptimizer(object):
"""Clips gradient norm."""
return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)
def step(self, closure=None, scale=1.0):
def step(self, closure=None, scale=1.0, groups=None):
"""Performs a single optimization step."""
if self.supports_step_with_scale:
self.optimizer.step(closure, scale=scale)
if self.supports_groups:
self.optimizer.step(closure, scale=scale, groups=groups)
else:
self.optimizer.step(closure, scale=scale)
else:
if scale != 1.0:
self.multiply_grads(1.0 / scale)
self.optimizer.step(closure)
if self.supports_groups:
self.optimizer.step(closure, groups=groups)
else:
self.optimizer.step(closure)
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
@@ -136,6 +143,12 @@ class FairseqOptimizer(object):
return self.optimizer.supports_step_with_scale
return False
@property
def supports_groups(self):
if hasattr(self.optimizer, "supports_groups"):
return self.optimizer.supports_groups
return False
@property
def supports_flat_params(self):
"""

View File

@@ -65,6 +65,8 @@ class _FP16OptimizerMixin(object):
for p in params:
p32 = torch.nn.Parameter(p.data.float())
p32.grad = torch.zeros_like(p32.data)
if hasattr(p, "param_group"):
p32.param_group = p.param_group
fp32_params.append(p32)
return fp32_params
@@ -198,15 +200,15 @@ class _FP16OptimizerMixin(object):
return grad_norm
def step(self, closure=None):
def step(self, closure=None, groups=None):
"""Performs a single optimization step."""
self._sync_fp16_grads_to_fp32()
if getattr(self, "supports_step_with_scale", False):
self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor))
self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups)
else:
self._unscale_grads()
self.fp32_optimizer.step(closure)
self.fp32_optimizer.step(closure, groups=groups)
if self.scaler is not None:
self.scaler.update()
@@ -303,6 +305,10 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
def optimizer(self, optimizer):
self.fp32_optimizer.optimizer = optimizer
@property
def lr_scheduler(self):
return getattr(self.fp32_optimizer, "lr_scheduler", None)
@property
def optimizer_config(self):
return self.fp32_optimizer.optimizer_config
@@ -416,14 +422,14 @@ class _MemoryEfficientFP16OptimizerMixin(object):
return grad_norm
def step(self, closure=None):
def step(self, closure=None, groups=None):
"""Performs a single optimization step."""
if getattr(self, "supports_step_with_scale", False):
# NOTE(msb) optimizer divides by scale factor
self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor))
self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups)
else:
self._unscale_grads()
self.wrapped_optimizer.step(closure)
self.wrapped_optimizer.step(closure, groups=groups)
if self.scaler is not None:
self.scaler.update()
@@ -514,6 +520,10 @@ class MemoryEfficientFP16Optimizer(
def optimizer_config(self):
return self.wrapped_optimizer.optimizer_config
@property
def lr_scheduler(self):
return getattr(self.wrapped_optimizer, "lr_scheduler", None)
def get_lr(self):
return self.wrapped_optimizer.get_lr()

View File

@@ -12,7 +12,7 @@ from fairseq.optim import FairseqOptimizer
class FairseqLRScheduler(object):
def __init__(self, cfg, optimizer):
super().__init__()
if not isinstance(optimizer, FairseqOptimizer):
if optimizer is not None and not isinstance(optimizer, FairseqOptimizer):
raise ValueError("optimizer must be an instance of FairseqOptimizer")
self.cfg = cfg
self.optimizer = optimizer

View File

@@ -0,0 +1,39 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
@dataclass
class PassThroughScheduleConfig(FairseqDataclass):
pass
@register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig)
class PassThroughScheduleSchedule(FairseqLRScheduler):
"""Delegate lr scheduling to the optimizer."""
def __init__(self, cfg: PassThroughScheduleConfig, optimizer):
super().__init__(cfg, optimizer)
assert (
hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None
), "Pass-through schedule can only be used with optimizers with their own schedulers"
def state_dict(self):
return self.optimizer.lr_scheduler.state_dict()
def load_state_dict(self, state_dict):
self.optimizer.lr_scheduler.load_state_dict(state_dict)
def step_begin_epoch(self, epoch):
"""Update the learning rate at the beginning of the given epoch."""
return self.optimizer.lr_scheduler.step_begin_epoch(epoch)
def step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.optimizer.lr_scheduler.step_update(num_updates)

View File

@@ -75,7 +75,7 @@ class NAG(Optimizer):
momentum = group["momentum"]
lr = group["lr"]
lr_old = group.get("lr_old", lr)
lr_correct = lr / lr_old
lr_correct = lr / lr_old if lr_old > 0 else lr
for p in group["params"]:
if p.grad is None:

View File

@@ -438,6 +438,9 @@ class FairseqTask(object):
loss, sample_size, logging_output = criterion(model, sample)
return loss, sample_size, logging_output
def optimizer_step(self, optimizer, model, update_num):
optimizer.step()
def build_dataset_for_inference(
self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
) -> torch.utils.data.Dataset:
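The new optimizer_step hook, together with the groups argument threaded through the optimizer step methods above, lets a task decide which parameter groups to step on a given update. A hypothetical sketch of such an override (the task class and the alternating schedule are made up for illustration, not part of this diff):

```python
from fairseq.tasks import FairseqTask


class AlternatingGroupTask(FairseqTask):
    """Hypothetical task that skips the "pos_conv" group on odd updates."""

    def optimizer_step(self, optimizer, model, update_num):
        if update_num % 2 == 0:
            # groups=None (the default) steps every registered optimizer
            optimizer.step()
        else:
            # only the optimizer registered for the "default" group is stepped
            optimizer.step(groups={"default"})
```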

View File

@@ -637,7 +637,9 @@ class Trainer(object):
with torch.autograd.profiler.record_function("optimizer"):
# take an optimization step
self.optimizer.step()
self.task.optimizer_step(
self.optimizer, model=self.model, update_num=self.get_num_updates()
)
except FloatingPointError:
# re-run the forward and backward pass with hooks attached to print
@@ -827,7 +829,12 @@ class Trainer(object):
def lr_step_update(self):
"""Update the learning rate after each update."""
new_lr = self.lr_scheduler.step_update(self.get_num_updates())
metrics.log_scalar("lr", new_lr, weight=0, priority=300)
if isinstance(new_lr, dict):
for k, v in new_lr.items():
metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300)
new_lr = new_lr.get("default", next(iter(new_lr.values())))
else:
metrics.log_scalar("lr", new_lr, weight=0, priority=300)
return new_lr
def get_lr(self):