Migrate remaining LR schedulers (#1448)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1448

Test Plan: Imported from OSS

Reviewed By: alexeib

Differential Revision: D25092150

Pulled By: myleott

fbshipit-source-id: fd066a0eba388bb0c344082a8fa1132974d53d40
Myle Ott 2020-11-20 05:59:25 -08:00 committed by Facebook GitHub Bot
parent 7171cdec5b
commit 40fbb37443
9 changed files with 191 additions and 158 deletions
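
Every file below follows the same migration pattern: the scheduler's argparse-based add_args block becomes a FairseqDataclass config passed to register_lr_scheduler, and args.* accesses become typed cfg.* accesses. A minimal standalone sketch of that pattern (the registry and decorator here are simplified stand-ins for illustration, not fairseq's actual implementation):

from dataclasses import dataclass, field

LR_SCHEDULER_REGISTRY = {}  # simplified stand-in for fairseq's registry

def register_lr_scheduler(name, dataclass=None):
    # the decorator records both the scheduler class and its config dataclass
    def wrapper(cls):
        LR_SCHEDULER_REGISTRY[name] = (cls, dataclass)
        return cls
    return wrapper

@dataclass
class FixedLRScheduleConfig:
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )

@register_lr_scheduler("fixed", dataclass=FixedLRScheduleConfig)
class FixedLRSchedule:
    def __init__(self, cfg, optimizer):
        self.cfg = cfg          # typed cfg replaces the old argparse Namespace
        self.optimizer = optimizer

cls, cfg_cls = LR_SCHEDULER_REGISTRY["fixed"]
scheduler = cls(cfg_cls(warmup_updates=4000), optimizer=None)
print(scheduler.cfg.warmup_updates)  # 4000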

View File

@@ -227,7 +227,11 @@ def _override_attr(
         if isinstance(val, tuple):
             val = list(val)
-        if getattr(v.type, "__origin__", None) is List:
+        if (
+            getattr(v.type, "__origin__", None) is List
+            # skip interpolation
+            and not (isinstance(val, str) and val.startswith("${"))
+        ):
             # if type is int but val is float, then we will crash later - try to convert here
             t_args = v.type.__args__
             if len(t_args) == 1:
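
The added guard matters because interpolation placeholders such as "${optimization.lr}" are still plain strings at this point and must not be coerced into typed lists. A standalone sketch of the check; needs_list_coercion is an illustrative helper, not a fairseq function:

from typing import List, get_origin

def needs_list_coercion(field_type, val):
    # interpolation placeholders like "${optimization.lr}" are resolved by
    # OmegaConf later, so leave them untouched
    if isinstance(val, str) and val.startswith("${"):
        return False
    return get_origin(field_type) is list

print(needs_list_coercion(List[int], "${optimization.lr}"))  # False
print(needs_list_coercion(List[int], (1, 2)))                # True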

View File

@@ -8,14 +8,14 @@ from collections import Collection
 from dataclasses import dataclass, field
 from typing import List

-from fairseq.dataclass import FairseqDataclass
-from omegaconf import II, DictConfig
+from omegaconf import II

-from . import FairseqLRScheduler, register_lr_scheduler
+from fairseq.dataclass import FairseqDataclass
+from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


 @dataclass
-class CosineConfig(FairseqDataclass):
+class CosineLRScheduleConfig(FairseqDataclass):
     warmup_updates: int = field(
         default=0,
         metadata={"help": "warmup the learning rate linearly for the first N updates"},
@@ -23,11 +23,11 @@ class CosineConfig(FairseqDataclass):
     warmup_init_lr: float = field(
         default=-1,
         metadata={
-            "help": "initial learning rate during warmup phase; default is args.lr"
+            "help": "initial learning rate during warmup phase; default is cfg.lr"
         },
     )
     max_lr: float = field(
-        default=1.0, metadata={"help": "max learning rate, must be more than args.lr"}
+        default=1.0, metadata={"help": "max learning rate, must be more than cfg.lr"}
     )
     t_mult: float = field(
         default=1.0, metadata={"help": "factor to grow the length of each period"}
@@ -38,13 +38,12 @@ class CosineConfig(FairseqDataclass):
     lr_shrink: float = field(
         default=0.1, metadata={"help": "shrink factor for annealing"}
     )
-    # TODO common var for parent class
     lr: List[float] = II("optimization.lr")
     max_update: int = II("optimization.max_update")


-@register_lr_scheduler("cosine", dataclass=CosineConfig)
-class CosineSchedule(FairseqLRScheduler):
+@register_lr_scheduler("cosine", dataclass=CosineLRScheduleConfig)
+class CosineLRSchedule(FairseqLRScheduler):
     """Assign LR based on a cyclical schedule that follows the cosine function.

     See https://arxiv.org/pdf/1608.03983.pdf for details.
@@ -55,7 +54,7 @@ class CosineSchedule(FairseqLRScheduler):
     During warmup::

-      lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
+      lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates)
       lr = lrs[update_num]

     After warmup::
@@ -67,9 +66,7 @@ class CosineSchedule(FairseqLRScheduler):
     after every iteration.
     """

-    def __init__(
-        self, cfg: DictConfig, fairseq_optimizer
-    ):
+    def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer):
         super().__init__(cfg, fairseq_optimizer)
         if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
             raise ValueError(
@@ -78,11 +75,7 @@ class CosineSchedule(FairseqLRScheduler):
             )

         warmup_end_lr = cfg.max_lr
-        lr = (
-            cfg.lr[0]
-            if isinstance(cfg.lr, Collection)
-            else cfg.lr
-        )
+        lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
         if cfg.warmup_init_lr < 0:
             cfg.warmup_init_lr = lr
@@ -100,10 +93,8 @@
         self.period = cfg.max_update - cfg.warmup_updates

         if cfg.warmup_updates > 0:
-            # linearly warmup for the first args.warmup_updates
-            self.lr_step = (
-                warmup_end_lr - cfg.warmup_init_lr
-            ) / cfg.warmup_updates
+            # linearly warmup for the first cfg.warmup_updates
+            self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates
         else:
             self.lr_step = 1
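
II("optimization.lr") in the config above is an OmegaConf interpolation: the scheduler config picks up lr from the top-level optimization group when the config is resolved. A minimal sketch of how that resolution works (the two-group layout here is assumed for illustration, not fairseq's full schema):

from dataclasses import dataclass, field
from typing import List
from omegaconf import II, OmegaConf

@dataclass
class OptimizationConfig:
    lr: List[float] = field(default_factory=lambda: [0.25])

@dataclass
class CosineLRScheduleConfig:
    # II("optimization.lr") expands to the string "${optimization.lr}"
    lr: List[float] = II("optimization.lr")

cfg = OmegaConf.create(
    {
        "optimization": OmegaConf.structured(OptimizationConfig),
        "lr_scheduler": OmegaConf.structured(CosineLRScheduleConfig),
    }
)
print(cfg.lr_scheduler.lr)  # [0.25], resolved against optimization.lr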

View File

@@ -6,8 +6,7 @@
 from argparse import Namespace

 from fairseq.dataclass.utils import gen_parser_from_dataclass
-from .. import FairseqOptimizer
+from fairseq.optim import FairseqOptimizer


 class FairseqLRScheduler(object):

View File

@@ -3,37 +3,44 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from . import LegacyFairseqLRScheduler, register_lr_scheduler
+from dataclasses import dataclass, field
+from typing import Optional, List
+from omegaconf import II
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


-@register_lr_scheduler("fixed")
-class FixedSchedule(LegacyFairseqLRScheduler):
+@dataclass
+class FixedLRScheduleConfig(FairseqDataclass):
+    force_anneal: Optional[int] = field(
+        default=None,
+        metadata={"help": "force annealing at specified epoch"},
+    )
+    lr_shrink: float = field(
+        default=0.1,
+        metadata={"help": "shrink factor for annealing, lr_new = (lr * lr_shrink)"},
+    )
+    warmup_updates: int = field(
+        default=0,
+        metadata={"help": "warmup the learning rate linearly for the first N updates"},
+    )
+    lr: List[float] = II("optimization.lr")
+
+
+@register_lr_scheduler("fixed", dataclass=FixedLRScheduleConfig)
+class FixedLRSchedule(FairseqLRScheduler):
     """Decay the LR on a fixed schedule."""

-    def __init__(self, args, optimizer):
-        super().__init__(args, optimizer)
+    def __init__(self, cfg: FixedLRScheduleConfig, optimizer):
+        super().__init__(cfg, optimizer)
         # set defaults
-        args.warmup_updates = getattr(args, "warmup_updates", 0) or 0
-
-        self.lr = args.lr[0]
-        if args.warmup_updates > 0:
-            self.warmup_factor = 1.0 / args.warmup_updates
+        self.lr = cfg.lr[0]
+        if cfg.warmup_updates > 0:
+            self.warmup_factor = 1.0 / cfg.warmup_updates
         else:
             self.warmup_factor = 1

-    @staticmethod
-    def add_args(parser):
-        """Add arguments to the parser for this LR scheduler."""
-        # fmt: off
-        parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
-                            help='force annealing at specified epoch (epochs start at 1)')
-        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
-                            help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
-        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
-                            help='warmup the learning rate linearly for the first N updates')
-        # fmt: on
-
     def state_dict(self):
         return {"lr": self.lr}
@@ -42,14 +49,14 @@ class FixedSchedule(LegacyFairseqLRScheduler):
         self.lr = state_dict["lr"]

     def get_next_lr(self, epoch):
-        lrs = self.args.lr
-        if self.args.force_anneal is None or epoch < self.args.force_anneal:
+        lrs = self.cfg.lr
+        if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal:
             # use fixed LR schedule
             next_lr = lrs[min(epoch - 1, len(lrs) - 1)]
         else:
             # anneal based on lr_shrink
-            next_lr = lrs[-1] * self.args.lr_shrink ** (
-                epoch + 1 - self.args.force_anneal
+            next_lr = lrs[-1] * self.cfg.lr_shrink ** (
+                epoch + 1 - self.cfg.force_anneal
             )
         return next_lr
@@ -61,8 +68,8 @@ class FixedSchedule(LegacyFairseqLRScheduler):
     def step_update(self, num_updates):
         """Update the learning rate after each update."""
-        if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates:
-            self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates)
+        if self.cfg.warmup_updates > 0 and num_updates < self.cfg.warmup_updates:
+            self.warmup_factor = (num_updates + 1) / float(self.cfg.warmup_updates)
             self.optimizer.set_lr(self.warmup_factor * self.lr)
         else:
             self.optimizer.set_lr(self.lr)
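
For reference, the epoch logic kept by the renamed class is unchanged. A standalone sketch of get_next_lr with plain arguments; the values below are hypothetical:

def get_next_lr(lrs, lr_shrink, force_anneal, epoch):
    if force_anneal is None or epoch < force_anneal:
        # use fixed LR schedule
        return lrs[min(epoch - 1, len(lrs) - 1)]
    # anneal based on lr_shrink
    return lrs[-1] * lr_shrink ** (epoch + 1 - force_anneal)

print(get_next_lr([0.1, 0.05], 0.1, force_anneal=3, epoch=1))  # 0.1
print(get_next_lr([0.1, 0.05], 0.1, force_anneal=3, epoch=4))  # 0.05 * 0.1**2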

View File

@@ -7,14 +7,14 @@ from collections import Collection
 from dataclasses import dataclass, field
 from typing import List

-from fairseq.dataclass import FairseqDataclass
-from omegaconf import II, DictConfig
+from omegaconf import II

-from . import FairseqLRScheduler, register_lr_scheduler
+from fairseq.dataclass import FairseqDataclass
+from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


 @dataclass
-class InverseSquareRootScheduleConfig(FairseqDataclass):
+class InverseSquareRootLRScheduleConfig(FairseqDataclass):
     warmup_updates: int = field(
         default=4000,
         metadata={"help": "warmup the learning rate linearly for the first N updates"},
@@ -22,14 +22,13 @@ class InverseSquareRootScheduleConfig(FairseqDataclass):
     warmup_init_lr: float = field(
         default=-1,
         metadata={
-            "help": "initial learning rate during warmup phase; default is args.lr"
+            "help": "initial learning rate during warmup phase; default is cfg.lr"
         },
     )
-    # TODO common vars at parent class
     lr: List[float] = II("optimization.lr")


-@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig)
+@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootLRScheduleConfig)
 class InverseSquareRootSchedule(FairseqLRScheduler):
     """Decay the LR based on the inverse square root of the update number.
@@ -40,36 +39,28 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
     During warmup::

-      lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
+      lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates)
       lr = lrs[update_num]

     After warmup::

-      decay_factor = args.lr * sqrt(args.warmup_updates)
+      decay_factor = cfg.lr * sqrt(cfg.warmup_updates)
       lr = decay_factor / sqrt(update_num)
     """

-    def __init__(self, cfg: DictConfig, optimizer):
+    def __init__(self, cfg: InverseSquareRootLRScheduleConfig, optimizer):
         super().__init__(cfg, optimizer)
         if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
             raise ValueError(
                 "Cannot use a fixed learning rate schedule with inverse_sqrt."
                 " Consider --lr-scheduler=fixed instead."
             )
-        warmup_end_lr = (
-            cfg.lr[0]
-            if isinstance(cfg.lr, Collection)
-            else cfg.lr
-        )
+        warmup_end_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
         if cfg.warmup_init_lr < 0:
-            cfg.warmup_init_lr = (
-                0 if cfg.warmup_updates > 0 else warmup_end_lr
-            )
+            cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr

-        # linearly warmup for the first args.warmup_updates
-        self.lr_step = (
-            warmup_end_lr - cfg.warmup_init_lr
-        ) / cfg.warmup_updates
+        # linearly warmup for the first cfg.warmup_updates
+        self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates

         # then, decay prop. to the inverse square root of the update number
         self.decay_factor = warmup_end_lr * cfg.warmup_updates ** 0.5
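
The docstring's two phases as a standalone function; the defaults mirror the config above, and the peak lr value is hypothetical:

import math

def inverse_sqrt_lr(update_num, lr=5e-4, warmup_updates=4000, warmup_init_lr=0.0):
    if update_num < warmup_updates:
        # linear warmup from warmup_init_lr to lr
        lr_step = (lr - warmup_init_lr) / warmup_updates
        return warmup_init_lr + update_num * lr_step
    # decay proportional to the inverse square root of the update number
    decay_factor = lr * math.sqrt(warmup_updates)
    return decay_factor / math.sqrt(update_num)

print(inverse_sqrt_lr(2000))   # halfway through warmup: 2.5e-4
print(inverse_sqrt_lr(16000))  # 5e-4 * sqrt(4000 / 16000) = 2.5e-4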

View File

@@ -8,11 +8,11 @@ from typing import Optional, List
 from omegaconf import II

 from fairseq.dataclass import FairseqDataclass
-from . import FairseqLRScheduler, register_lr_scheduler
+from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


 @dataclass
-class PolynomialDecayScheduleConfig(FairseqDataclass):
+class PolynomialDecayLRScheduleConfig(FairseqDataclass):
     warmup_updates: int = field(
         default=0,
         metadata={"help": "warmup the learning rate linearly for the first N updates"},
@@ -36,13 +36,11 @@ class PolynomialDecayScheduleConfig(FairseqDataclass):
     lr: List[float] = II("optimization.lr")


-@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayScheduleConfig)
-class PolynomialDecaySchedule(FairseqLRScheduler):
+@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayLRScheduleConfig)
+class PolynomialDecayLRSchedule(FairseqLRScheduler):
     """Decay the LR on a fixed schedule."""

-    cfg: PolynomialDecayScheduleConfig
-
-    def __init__(self, cfg: PolynomialDecayScheduleConfig, optimizer):
+    def __init__(self, cfg: PolynomialDecayLRScheduleConfig, optimizer):
         super().__init__(cfg, optimizer)
         assert cfg.total_num_update > 0
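
A standalone sketch of the polynomial decay rule this scheduler implements, with warmup omitted; end_lr and power stand in for config fields not shown in this hunk:

def poly_decay_lr(num_updates, lr=5e-4, end_lr=0.0, total_num_update=100000, power=1.0):
    if num_updates >= total_num_update:
        return end_lr
    # fraction of the schedule remaining, raised to `power` (1.0 = linear decay)
    pct_remaining = 1 - num_updates / total_num_update
    return (lr - end_lr) * pct_remaining ** power + end_lr

print(poly_decay_lr(50000))  # halfway through a linear decay: 2.5e-4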

View File

@@ -3,13 +3,59 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from dataclasses import dataclass, field
+from typing import List
+
 import torch.optim.lr_scheduler
+from omegaconf import II

-from . import LegacyFairseqLRScheduler, register_lr_scheduler
+from fairseq.dataclass import FairseqDataclass
+from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


-@register_lr_scheduler("reduce_lr_on_plateau")
-class ReduceLROnPlateau(LegacyFairseqLRScheduler):
+@dataclass
+class ReduceLROnPlateauLRScheduleConfig(FairseqDataclass):
+    lr_shrink: float = field(
+        default=0.1, metadata={"help": "shrink factor for annealing"}
+    )
+    lr_threshold: float = field(
+        default=1e-4,
+        metadata={
+            "help": (
+                "threshold for measuring the new optimum, to only focus on "
+                "significant changes"
+            )
+        },
+    )
+    lr_patience: int = field(
+        default=0,
+        metadata={
+            "help": (
+                "number of epochs with no improvement after which learning rate will "
+                "be reduced"
+            )
+        },
+    )
+    warmup_updates: int = field(
+        default=0,
+        metadata={"help": "warmup the learning rate linearly for the first N updates"},
+    )
+    warmup_init_lr: float = field(
+        default=-1,
+        metadata={
+            "help": "initial learning rate during warmup phase; default is cfg.lr"
+        },
+    )
+    lr: List[float] = II("optimization.lr")
+    maximize_best_checkpoint_metric: bool = II(
+        "checkpoint.maximize_best_checkpoint_metric"
+    )
+
+
+@register_lr_scheduler(
+    "reduce_lr_on_plateau", dataclass=ReduceLROnPlateauLRScheduleConfig
+)
+class ReduceLROnPlateauLRSchedule(FairseqLRScheduler):
     """
     Decay the LR by a factor every time the validation loss plateaus.

     Also comes with optional warmup phase, where we linearly increase
@@ -21,61 +67,43 @@ class ReduceLROnPlateau(LegacyFairseqLRScheduler):
     During warmup::

-      lrs = torch.linspace(
-          args.warmup_init_lr, args.lr, args.warmup_updates
-      )
+      lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates)
       lr = lrs[update_num]
     """

-    def __init__(self, args, optimizer):
-        super().__init__(args, optimizer)
-        if len(args.lr) > 1:
+    def __init__(self, cfg: ReduceLROnPlateauLRScheduleConfig, optimizer):
+        super().__init__(cfg, optimizer)
+        if len(cfg.lr) > 1:
             raise ValueError(
                 "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau."
                 " Consider --lr-scheduler=fixed instead."
             )
         self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
             self.optimizer.optimizer,
-            patience=args.lr_patience,
-            factor=args.lr_shrink,
-            mode="max" if args.maximize_best_checkpoint_metric else "min",
-            threshold=args.lr_threshold,
+            patience=cfg.lr_patience,
+            factor=cfg.lr_shrink,
+            mode="max" if cfg.maximize_best_checkpoint_metric else "min",
+            threshold=cfg.lr_threshold,
         )
-        warmup_end_lr = args.lr[0]
-        # if no warm up, sets initial lr to be args.lr[0]
-        if args.warmup_init_lr < 0:
-            args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr
+        warmup_end_lr = cfg.lr[0]
+        # if no warm up, sets initial lr to be cfg.lr[0]
+        if cfg.warmup_init_lr < 0:
+            cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr

-        # linearly warmup for the first args.warmup_updates
-        if args.warmup_updates > 0:
-            self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
+        # linearly warmup for the first cfg.warmup_updates
+        if cfg.warmup_updates > 0:
+            self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates

         # this flag is either set from arg when no warm up, or set by
         # step_update() when warmup finishes
-        self.warmup_end = True if args.warmup_updates <= 0 else False
+        self.warmup_end = True if cfg.warmup_updates <= 0 else False

         # initial learning rate
         # this self.lr is used only during init and/or warm up period
-        self.lr = args.warmup_init_lr
+        self.lr = cfg.warmup_init_lr
         self.optimizer.set_lr(self.lr)

-    @staticmethod
-    def add_args(parser):
-        """Add arguments to the parser for this LR scheduler."""
-        # fmt: off
-        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
-                            help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
-        parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT',
-                            help='threshold for measuring the new optimum, '
-                                 'to only focus on significant changes')
-        parser.add_argument('--lr-patience', default=0, type=int,
-                            help='number of epochs with no improvement after which '
-                                 'learning rate will be reduced')
-        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
-                            help='warmup the learning rate linearly for the first N updates')
-        parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
-                            help='initial learning rate during warmup phase; default is args.lr')
-        # fmt: on
-
     def state_dict(self):
         """Return the LR scheduler state dict."""
         return {
@@ -104,9 +132,9 @@ class ReduceLROnPlateau(LegacyFairseqLRScheduler):
         """
         Update the learning rate after each update."""
         # if there is warmup
-        if self.args.warmup_updates > 0:
-            if num_updates <= self.args.warmup_updates:
-                self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
+        if self.cfg.warmup_updates > 0:
+            if num_updates <= self.cfg.warmup_updates:
+                self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
                 self.optimizer.set_lr(self.lr)
             else:
                 if self.warmup_end is False:
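
The migrated class still delegates to torch's scheduler; only the source of the hyperparameters changed. A minimal sketch of the wrapped torch call with the same keyword arguments, using a toy model and optimizer rather than fairseq objects:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=0,      # cfg.lr_patience
    factor=0.1,      # cfg.lr_shrink
    mode="min",      # "max" if maximizing the checkpoint metric
    threshold=1e-4,  # cfg.lr_threshold
)
for val_loss in [1.0, 1.0, 1.0]:
    scheduler.step(val_loss)  # loss plateaus, so the LR shrinks by `factor`
print(optimizer.param_groups[0]["lr"])  # reduced from the initial 0.1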

View File

@@ -4,13 +4,12 @@
 # LICENSE file in the root directory of this source tree.

 import math
 from dataclasses import dataclass, field
 from typing import Optional, List, Tuple

 from omegaconf import II

 from fairseq.dataclass import FairseqDataclass
-from . import FairseqLRScheduler, register_lr_scheduler
+from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


 @dataclass
@@ -29,8 +28,12 @@ class TriStageLRScheduleConfig(FairseqDataclass):
     )
     phase_ratio: Optional[Tuple[float, float, float]] = field(
         default=None,
-        metadata={"help": "if set, automatically sets warmup/hold/decay steps to the ratio specified here "
-                          "from max_updates. the ratios must add up to 1.0"},
+        metadata={
+            "help": (
+                "if set, automatically sets warmup/hold/decay steps to the ratio "
+                "specified here from max_updates. the ratios must add up to 1.0"
+            )
+        },
     )
     init_lr_scale: float = field(
         default=0.01,
@@ -42,7 +45,7 @@
     )
     max_update: float = II("optimization.max_update")
     lr: List[float] = II("optimization.lr")


 @register_lr_scheduler("tri_stage", dataclass=TriStageLRScheduleConfig)
 class TriStageLRScheduleConfig(FairseqLRScheduler):
@@ -90,6 +93,7 @@ class TriStageLRScheduleConfig(FairseqLRScheduler)
                 "Cannot use a fixed learning rate schedule with tri-stage lr."
                 " Consider --lr-scheduler=fixed instead."
             )
+        assert cfg.max_update > 0

         # calculate LR at each point
         self.peak_lr = cfg.lr[0]
@@ -97,7 +101,7 @@ class TriStageLRScheduleConfig(FairseqLRScheduler)
         self.final_lr = cfg.final_lr_scale * cfg.lr[0]

         if cfg.phase_ratio is not None:
-            assert sum(cfg.phase_ratio) == 1, 'phase ratios must add up to 1'
+            assert sum(cfg.phase_ratio) == 1, "phase ratios must add up to 1"
             self.warmup_steps = int(cfg.max_update * cfg.phase_ratio[0])
             self.hold_steps = int(cfg.max_update * cfg.phase_ratio[1])
             self.decay_steps = int(cfg.max_update * cfg.phase_ratio[2])
@@ -105,8 +109,10 @@ class TriStageLRScheduleConfig(FairseqLRScheduler)
             self.warmup_steps = cfg.warmup_steps
             self.hold_steps = cfg.hold_steps
             self.decay_steps = cfg.decay_steps
-        assert self.warmup_steps + self.hold_steps + self.decay_steps > 0, "please specify steps or phase_ratio"
+        assert (
+            self.warmup_steps + self.hold_steps + self.decay_steps > 0
+        ), "please specify steps or phase_ratio"

         self.warmup_rate = (
             (self.peak_lr - self.init_lr) / self.warmup_steps
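
The phase_ratio branch above splits max_update into the three stages. A worked sketch with assumed values:

max_update = 10000
phase_ratio = (0.1, 0.4, 0.5)
assert sum(phase_ratio) == 1, "phase ratios must add up to 1"
warmup_steps = int(max_update * phase_ratio[0])  # 1000
hold_steps = int(max_update * phase_ratio[1])    # 4000
decay_steps = int(max_update * phase_ratio[2])   # 5000
print(warmup_steps, hold_steps, decay_steps)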

View File

@@ -4,52 +4,61 @@
 # LICENSE file in the root directory of this source tree.

 import math
+from dataclasses import dataclass, field
+from typing import List

-from . import LegacyFairseqLRScheduler, register_lr_scheduler
+from omegaconf import II
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


-@register_lr_scheduler("triangular")
-class TriangularSchedule(LegacyFairseqLRScheduler):
+@dataclass
+class TriangularLRScheduleConfig(FairseqDataclass):
+    max_lr: float = field(
+        default="???", metadata={"help": "max learning rate, must be more than cfg.lr"}
+    )
+    lr_period_updates: float = field(
+        default=5000,
+        metadata={"help": "initial number of updates per period (cycle length)"},
+    )
+    lr_shrink: float = field(
+        default=0.1, metadata={"help": "shrink factor for annealing"}
+    )
+    shrink_min: bool = field(
+        default=False, metadata={"help": "if set, also shrinks min lr"}
+    )
+    lr: List[float] = II("optimization.lr")
+
+
+@register_lr_scheduler("triangular", dataclass=TriangularLRScheduleConfig)
+class TriangularLRSchedule(FairseqLRScheduler):
     """Assign LR based on a triangular cyclical schedule.

     See https://arxiv.org/pdf/1506.01186.pdf for details.
     """

-    def __init__(self, args, optimizer):
-        super().__init__(args, optimizer)
-        if len(args.lr) > 1:
+    def __init__(self, cfg: TriangularLRScheduleConfig, optimizer):
+        super().__init__(cfg, optimizer)
+        if len(cfg.lr) > 1:
             raise ValueError(
                 "Cannot use a fixed learning rate schedule with triangular."
                 " Consider --lr-scheduler=fixed instead."
             )

-        lr = args.lr[0]
+        lr = cfg.lr[0]

-        assert args.max_lr > lr, "max_lr must be more than lr"
+        assert cfg.max_lr > lr, "max_lr must be more than lr"
         self.min_lr = lr
-        self.max_lr = args.max_lr
-        self.stepsize = args.lr_period_updates // 2
-        self.lr_shrink = args.lr_shrink
-        self.shrink_min = args.shrink_min
+        self.max_lr = cfg.max_lr
+        self.stepsize = cfg.lr_period_updates // 2
+        self.lr_shrink = cfg.lr_shrink
+        self.shrink_min = cfg.shrink_min

         # initial learning rate
         self.lr = self.min_lr
         self.optimizer.set_lr(self.lr)

-    @staticmethod
-    def add_args(parser):
-        """Add arguments to the parser for this LR scheduler."""
-        # fmt: off
-        parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
-                            help='max learning rate, must be more than args.lr')
-        parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
-                            help='initial number of updates per period (cycle length)')
-        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
-                            help='shrink factor for annealing')
-        parser.add_argument('--shrink-min', action='store_true',
-                            help='if set, also shrinks min lr')
-        # fmt: on
-
     def step(self, epoch, val_loss=None):
         """Update the learning rate at the end of the given epoch."""
         super().step(epoch, val_loss)
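
A standalone sketch of the triangular cycle (Smith, 2015) referenced in the docstring, omitting the lr_shrink/shrink_min annealing; all values are hypothetical:

import math

def triangular_lr(num_updates, min_lr=0.01, max_lr=0.1, lr_period_updates=5000):
    # the LR bounces linearly between min_lr and max_lr once per period
    stepsize = lr_period_updates // 2
    cycle = math.floor(num_updates / (2 * stepsize))
    x = abs(num_updates / stepsize - 2 * cycle - 1)
    return min_lr + (max_lr - min_lr) * max(0, 1 - x)

print(triangular_lr(0))     # 0.01 (bottom of the cycle)
print(triangular_lr(2500))  # 0.1  (peak, halfway through the period)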