Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-11-20 18:41:02 +03:00)

Commit 58a7aca97f (parent da2777d7cb)
Optimizer support for least squares binarization
@@ -3,7 +3,7 @@ import torch.nn as nn
 from torch import Tensor
 from typing import Optional

-from .quant import (
+from fairseq.optim.quant import (
     Binarizer,
     get_qminmax,
     get_scale_init,
@@ -27,9 +27,14 @@ class AdaptedLinear(nn.Linear):


 class BinarizedLinear(AdaptedLinear):
+    def __init__(self, weight_init: Tensor, bias_init: Optional[Tensor]):
+        super().__init__(weight_init, bias_init)
+
+        self.register_buffer("v", torch.zeros(1))
+
     @property
     def weight(self):
-        return Binarizer.apply(self._weight)
+        return Binarizer.apply(self._weight, self.v, self.training)


 class LSQLinear(AdaptedLinear):
@@ -56,15 +61,12 @@ class QuantLS2Linear(AdaptedLinear):
     ):
         super().__init__(weight_init, bias_init)

-        self.register_buffer("v1", torch.zeros(1))
-        self.register_buffer("v2", torch.zeros(1))
+        self.register_buffer("vs", torch.zeros(2))
         self.stride = stride

     @property
     def weight(self):
-        return QuantLS2.apply(
-            self._weight, self.v1, self.v2, self.stride, self.training
-        )
+        return QuantLS2.apply(self._weight, self.vs, self.training, self.stride)

     def extra_repr(self):
         stride = self.stride
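The module-side hunks above route every read of `weight` through a quantizing autograd Function and keep the learned scale(s) in registered buffers (`v`, `vs`) that are refreshed only while training. A minimal, self-contained sketch of that pattern, with hypothetical `_SignSTE` and `ToyBinaryLinear` names rather than the classes from this diff:

import torch
import torch.nn as nn
import torch.nn.functional as F


class _SignSTE(torch.autograd.Function):
    # Scaled sign in the forward pass, straight-through (identity) gradient in backward.
    @staticmethod
    def forward(ctx, weight, v, training):
        if training:
            v.copy_(weight.abs().mean())  # refresh the layer-wise scale from the latent weights
        return weight.sign().mul(v)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None, None  # one entry per forward input


class ToyBinaryLinear(nn.Module):
    """Toy layer: full-precision latent weight, binarized view used in forward."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self._weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.register_buffer("v", torch.zeros(1))  # same role as `self.v` in the hunk above

    @property
    def weight(self):
        return _SignSTE.apply(self._weight, self.v, self.training)

    def forward(self, x):
        return F.linear(x, self.weight)


layer = ToyBinaryLinear(16, 8)
layer(torch.randn(4, 16)).sum().backward()
print(layer._weight.grad.shape)  # gradients reach the latent fp32 weight: torch.Size([8, 16])

Keeping the scale in a buffer lets evaluation reuse the value computed during training, which is why `self.training` is passed into `apply`.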
@@ -13,8 +13,9 @@ import torch
 import torch.distributed as dist
 import torch.optim
 from fairseq.dataclass import FairseqDataclass
-from fairseq.optim import FairseqOptimizer, quant_utils, register_optimizer
+from fairseq.optim import FairseqOptimizer, register_optimizer
 from fairseq.optim.fused_adam import get_fused_adam_class
+from fairseq.optim.quant_optimizer import QuantOptimizer
 from omegaconf import II, OmegaConf
@@ -56,8 +57,8 @@ class FairseqAdam(FairseqOptimizer):
         super().__init__(cfg)
         quant_bits = getattr(cfg, "quant_bits", 32)
         quant_method = getattr(cfg, "quant_method", "none")
-        if quant_bits > 2 and quant_method != "least-sq":
-            raise NotImplementedError(f"{quant_method=} not supported yet")
+        if 1 < quant_bits < 32 and quant_method == "parq":
+            raise NotImplementedError

         fused_adam_cls = get_fused_adam_class()
         use_fused_adam = (
@@ -121,7 +122,7 @@ class FairseqAdam(FairseqOptimizer):
             dist.all_reduce(value["exp_avg_sq"], op=dist.ReduceOp.SUM)


-class Adam(torch.optim.Optimizer):
+class Adam(QuantOptimizer):
     r"""Implements Adam algorithm.

     This implementation is modified from torch.optim.Adam based on:
@@ -178,11 +179,6 @@ class Adam(torch.optim.Optimizer):
     def supports_flat_params(self):
         return True

-    @staticmethod
-    def binarize_param(p_data_fp32) -> torch.Tensor:
-        omega = quant_utils.estimate_omega(p_data_fp32)
-        return quant_utils.scaled_sign_(p_data_fp32, omega)
-
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -206,9 +202,11 @@
                         "Adam does not support sparse gradients, please consider SparseAdam instead"
                     )
                 amsgrad = group.get("amsgrad", False)
-                apply_ste = (
-                    group["quant_method"] == "least-sq" and group["quant_bits"] < 32
-                )
+                apply_ste = group["quant_bits"] < 32 and group["quant_method"] in {
+                    "lsq",
+                    "least-sq",
+                }

                 p_data_fp32 = p.data
                 if p.data.dtype in {torch.float16, torch.bfloat16}:
@@ -265,8 +263,8 @@

                 p_buf.addcdiv_(exp_avg, denom, value=-step_size)

-                if apply_ste:  # load latent_p then binarize
-                    self.binarize_param(p_data_fp32.copy_(p_buf))
+                if apply_ste:
+                    self.quantize_param_(group, state, p_buf, p)

                 if p.data.dtype in {torch.float16, torch.bfloat16}:
                     p.data.copy_(p_data_fp32)
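With these changes the Adam step no longer binarizes in place through a local `binarize_param` helper; when `apply_ste` is set it updates a full-precision buffer and asks the shared `QuantOptimizer.quantize_param_` to write the re-quantized values back into the parameter. A rough sketch of that straight-through flow, using plain SGD-style updates and hypothetical variable names:

import torch

# Rough sketch of the flow used when apply_ste is True: gradients are taken w.r.t. the
# quantized parameter, the update is applied to a full-precision latent copy, and the
# parameter is then re-quantized from that copy.
p = torch.randn(8)                      # the (quantized) parameter the model computes with
latent = p.clone()                      # plays the role of p_buf / state["latent_p"]
lr = 0.1

for _ in range(3):
    grad = torch.randn_like(p)          # stand-in for the gradient of the loss w.r.t. p
    latent.add_(grad, alpha=-lr)        # optimizer update on the latent weights
    v = latent.abs().mean()             # L1-optimal scale for 1-bit quantization
    p.copy_(latent.sign().mul(v))       # write the re-quantized values back into p
print(p.unique())                       # only +v / -v survive

The parameter the model actually uses therefore stays quantized after every step, while the optimizer keeps accumulating its state in full precision.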
@@ -70,8 +70,9 @@ class FairseqMADGRAD(FairseqOptimizer):

         quant_bits = getattr(cfg, "quant_bits", 32)
         quant_method = getattr(cfg, "quant_method", "none")
-        if quant_bits > 2 and quant_method != "least-sq":
-            raise NotImplementedError(f"{quant_method=} not supported yet")
+        if 1 < quant_bits < 32 and quant_method == "parq":
+            raise NotImplementedError

         self._optimizer = MADGRAD(params, **self.optimizer_config)

     @property
@@ -9,7 +9,7 @@ import torch.optim

 from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple

-from fairseq.optim import quant_utils
+from fairseq.optim.quant_optimizer import QuantOptimizer

 if TYPE_CHECKING:
     from torch.optim.optimizer import _params_t
@@ -17,7 +17,7 @@ else:
     _params_t = Any


-class MADGRAD(torch.optim.Optimizer):
+class MADGRAD(QuantOptimizer):
     """
     MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
     Optimization.
@@ -175,7 +175,7 @@ class MADGRAD(torch.optim.Optimizer):
             quant_method = group["quant_method"]

             apply_par = quant_bits == 1 and quant_method == "parq"
-            apply_ste = quant_bits == 1 and quant_method == "least-sq"
+            apply_ste = quant_bits < 32 and quant_method in {"lsq", "least-sq"}

             ck = 1 - momentum
             lamb = lr * math.pow(k + 1, 0.5)
@@ -253,11 +253,11 @@

                     if apply_par:
                         rms.div_(state["lamb_sum"].add_(lamb)).clamp_(max=1)
-                        omega = quant_utils.estimate_omega(z)
+                        v = z.abs().mean()
                         z = torch.where(
-                            z.abs() < omega.mul(rms),
+                            z.abs() < v.mul(rms),
                             z.div(rms),
-                            torch.sign(z).mul_(omega),
+                            torch.sign(z).mul_(v),
                         )
                         state["z"].copy_(z)
@@ -267,20 +267,19 @@
                         alpha=1 - ck,
                     )

-                if decouple_decay and decay != 0:
-                    p_old = (
+                if apply_ste or decouple_decay and decay != 0:
+                    p_buf = (
                         state["latent_p"] if apply_ste else p_data_fp32
                     ).clone()

                 if apply_ste:
                     state["latent_p"].copy_(z)
-                    omega = quant_utils.estimate_omega(z)
-                    quant_utils.scaled_sign_(z, omega)
+                    self.quantize_param_(group, state, state["latent_p"], z)

                 p_data_fp32.copy_(z)

                 if decouple_decay and decay != 0:
-                    p_data_fp32.add_(p_old, alpha=-lr * decay)
+                    p_data_fp32.add_(p_buf, alpha=-lr * decay)

                 if p.data.dtype in {torch.float16, torch.bfloat16}:
                     p.data.copy_(p_data_fp32)
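In the MADGRAD hunks the `parq` branch now derives its scale directly from the dual iterate (`v = z.abs().mean()`) and applies an element-wise projection: entries smaller than `v * rms` are rescaled by `1 / rms`, the rest collapse to `±v`. A small numeric sketch of that projection (standalone values, not taken from a real run):

import torch

# Standalone numeric sketch of the element-wise projection in the apply_par branch.
z = torch.tensor([0.05, -0.8, 0.3, -0.02])   # stand-in for the dual iterate
rms = torch.tensor(0.5)                      # stand-in for the clamped rms factor
v = z.abs().mean()                           # scale re-estimated from z, as in the new code

proj = torch.where(z.abs() < v.mul(rms), z.div(rms), torch.sign(z).mul(v))
print(v, proj)  # small entries are rescaled by 1 / rms, large ones collapse to +/- v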
@@ -12,9 +12,18 @@ def get_qminmax(quant_bits) -> Tuple[int, int]:
     return qmin, qmax


-def get_quant_cls(quant_bits) -> Type[torch.autograd.Function]:
+def get_quant_cls(quant_method: str, quant_bits: int) -> Type[torch.autograd.Function]:
     assert quant_bits < 32
-    return Binarizer if quant_bits == 1 else LSQClampRound
+    if quant_method == "lsq":
+        quant_cls = LSQClampRound
+    elif quant_method == "least-sq":
+        if quant_bits == 1:
+            quant_cls = Binarizer
+        elif quant_bits == 2:
+            quant_cls = QuantLS2
+        else:
+            quant_cls = QuantLSGreedy
+    return quant_cls


 def get_scale_init(input, qmax):
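The dispatcher now keys on both the method and the bit-width. Assumed usage, with the return values taken from the branches above:

from fairseq.optim.quant import get_quant_cls

# Assumed usage: the method picks the Function family, the bit-width picks the
# least-squares variant.
assert get_quant_cls("lsq", 4).__name__ == "LSQClampRound"
assert get_quant_cls("least-sq", 1).__name__ == "Binarizer"
assert get_quant_cls("least-sq", 2).__name__ == "QuantLS2"
assert get_quant_cls("least-sq", 3).__name__ == "QuantLSGreedy"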
@@ -25,21 +34,21 @@ def l1_normalized(input, dim=None, keepdim=False) -> torch.Tensor:
     return input.abs().mean(dim=dim, keepdim=keepdim)


-def scaled_sign(input) -> torch.Tensor:
-    omega = l1_normalized(input)
-    return input.sign().mul_(omega)
-
-
 class Binarizer(torch.autograd.Function):
     """1-bit binary quantization: https://arxiv.org/abs/1603.05279"""

     @staticmethod
-    def forward(ctx, input):
-        return scaled_sign(input)
+    def forward(ctx, input, v, training: bool = False):
+        if training:
+            v.copy_(l1_normalized(input))
+
+        input_sgn = input.sign().mul_(v)
+        return input_sgn

     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output if ctx.needs_input_grad[0] else None
+        grad_input = grad_output if ctx.needs_input_grad[0] else None
+        return grad_input, None, None


 class QuantLS2(torch.autograd.Function):
@@ -111,7 +120,7 @@ class QuantLS2(torch.autograd.Function):
         return v1

     @staticmethod
-    def forward(ctx, input, v1, v2, stride, training: bool = True):
+    def forward(ctx, input, vs, training: bool = False, stride: int = 5):
         if training:
             # Simulate striding along randomly permuted columns
             n_col = input.size(1) // stride
@@ -119,14 +128,14 @@
             strided_input = input[:, idx].abs()

             v1s = QuantLS2.get_v1_cands(strided_input)
-            v1.copy_(QuantLS2.compute_v1(strided_input, v1s))
+            vs[0].copy_(QuantLS2.compute_v1(strided_input, v1s))

-            residual = QuantLS2.update_residual(input, v1)
-            v2.copy_(l1_normalized(residual))
+            residual = QuantLS2.update_residual(input, vs[0])
+            vs[1].copy_(l1_normalized(residual))

         # Use v1 and v2 to compute quantized input
-        input_sgn = input.sign().mul(v1)
-        input_sgn.add_((input - input_sgn).sign_().mul_(v2))
+        input_sgn = input.sign().mul(vs[0])
+        input_sgn.add_((input - input_sgn).sign_().mul_(vs[1]))
         return input_sgn

     @staticmethod
@@ -220,3 +229,24 @@ class LSQClampRound(torch.autograd.Function):
         grad_target = LSQClampRound.get_grad_target(grad_output, mid_mask)

         return grad_target, grad_scale, None, None
+
+
+class LSQNaive(torch.nn.Module):
+    @staticmethod
+    def grad_scale(x, scale):
+        x_scale = x.mul(scale)
+        return (x - x_scale).detach() + x_scale
+
+    @staticmethod
+    def round_pass(x):
+        x_round = x.round()
+        return (x_round - x).detach() + x
+
+    @staticmethod
+    def forward(target, scale, qmin, qmax):
+        grad_scale_factor = 1.0 / math.sqrt(qmax * target.numel())
+        s = LSQNaive.grad_scale(scale, grad_scale_factor)
+
+        quant_target = target.div(s).clamp_(qmin, qmax)
+        quant_target = LSQNaive.round_pass(quant_target).mul_(s)
+        return quant_target
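`QuantLS2` implements two-term least-squares binarization: the weights are approximated as `vs[0] * sign(W) + vs[1] * sign(W - vs[0] * sign(W))`, with the first scale found by a candidate search over strided columns and the second taken as the mean absolute residual. A self-contained sketch of the underlying idea that uses the simple `mean(|W|)` scale for the first term instead of the candidate search:

import torch

torch.manual_seed(0)
w = torch.randn(64, 64)

# First term: scaled sign with the L1-optimal layer-wise scale.
v0 = w.abs().mean()
q1 = w.sign().mul(v0)

# Second term: binarize the residual the same way. (The diff searches candidate
# values for vs[0] over strided columns; mean(|w|) keeps this sketch simple.)
residual = w - q1
v1 = residual.abs().mean()
q2 = q1 + residual.sign().mul(v1)

for name, q in [("one-term", q1), ("two-term", q2)]:
    rmse = (w - q).pow(2).mean().sqrt()
    print(name, float(rmse))  # the second term lowers the reconstruction error

Because the second scale is the L1-optimal scale for the residual, the two-term reconstruction lowers (or at worst matches) the total squared error of one-term binarization.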
fairseq/optim/quant_optimizer.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+import torch
+
+from typing import Any, Dict
+
+from fairseq.optim.quant import (
+    get_qminmax,
+    get_quant_cls,
+    get_scale_init,
+)
+
+
+class QuantOptimizer(torch.optim.Optimizer):
+    def add_param_group(self, group: Dict[str, Any]) -> None:
+        super().add_param_group(group)
+
+        quant_bits = group["quant_bits"]
+        if quant_bits == 32:
+            return
+
+        quant_method = group["quant_method"]
+        if quant_method == "lsq":
+            # Set LSQ constants
+            group["qmin"], group["qmax"] = get_qminmax(quant_bits)
+
+        for p in group["params"]:
+            old_p = p.clone().detach()  # make full precision copy
+
+            group["quant_cls"] = get_quant_cls(quant_method, quant_bits)
+            if quant_method == "lsq":
+                # Initialize LSQ scale params
+                scale = torch.empty(1, device=old_p.device, dtype=old_p.dtype)
+                scale.copy_(get_scale_init(old_p, group["qmax"]))
+                p.data.copy_(
+                    group["quant_cls"].apply(old_p, scale, group["qmin"], group["qmax"])
+                )
+                self.state[p]["scale"] = scale
+            elif quant_method == "least-sq":
+                vs = torch.empty(quant_bits, device=old_p.device, dtype=old_p.dtype)
+                p.data.copy_(group["quant_cls"].apply(old_p, vs, True))
+                self.state[p]["vs"] = vs
+            self.state[p]["latent_p"] = old_p
+
+    @staticmethod
+    def quantize_param_(group, state, p_buf, p):
+        """Quantize `p_buf` based on `group["quant_method"]`, saving into `p`."""
+        quant_method = group["quant_method"]
+        quant_cls = group["quant_cls"]
+        if quant_method == "lsq":
+            quant_target = p_buf.div(state["scale"])
+            qmin, qmax = group["qmin"], group["qmax"]
+            args = (quant_target, qmin, qmax)
+            neg_mask, mid_mask, pos_mask = quant_cls.get_grad_masks(*args)
+
+            grad_scale = quant_cls.get_grad_scale(
+                *args, neg_mask, mid_mask, pos_mask, p.grad
+            )
+            state["scale"].sub_(grad_scale.sum(), alpha=group["lr"])
+
+            p.copy_(quant_cls.apply(p_buf, state["scale"], qmin, qmax))
+            p.grad.copy_(quant_cls.get_grad_target(p.grad, mid_mask))
+        elif quant_method == "least-sq":
+            p.copy_(quant_cls.apply(p_buf, state["vs"], True))
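`QuantOptimizer` hooks `add_param_group`, so parameters in a quantized group are snapped to their quantized values as soon as the group is registered, with the full-precision copy kept in `state["latent_p"]`; subclasses then call `quantize_param_` after each update, as the Adam and MADGRAD changes above do. A hedged usage sketch with a hypothetical `ToySGD` subclass (not part of this commit):

import torch
from fairseq.optim.quant_optimizer import QuantOptimizer


class ToySGD(QuantOptimizer):
    """Hypothetical subclass: plain SGD on a latent copy, re-quantized each step."""

    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if group["quant_bits"] == 32:
                    p.data.add_(p.grad, alpha=-group["lr"])
                    continue
                # STE: update the full-precision latent weights, then re-quantize.
                state["latent_p"].add_(p.grad, alpha=-group["lr"])
                self.quantize_param_(group, state, state["latent_p"], p.data)


w = torch.nn.Parameter(torch.randn(16, 16))
opt = ToySGD(
    [{"params": [w], "lr": 0.1, "quant_bits": 1, "quant_method": "least-sq"}],
    defaults={},
)
w.sum().backward()
opt.step()
print(w.data.unique().numel())  # 2: the weights stay binary after the update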
@@ -1,15 +0,0 @@
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-def estimate_omega(latent_p):
-    """Compute layer-wise scaling factor estimated from latent weights.
-    XNOR-Net: https://arxiv.org/abs/1603.05279
-    """
-    return latent_p.norm(p=1).div(latent_p.numel())
-
-
-def scaled_sign_(latent_p, omega):
-    """In-place sign function scaled by layer-wise factor."""
-    return latent_p.sign_().mul_(omega)
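The deleted `estimate_omega` helper returned `||w||_1 / n`, which is the closed-form minimizer of `||w - omega * sign(w)||^2` (the XNOR-Net scaling factor); that role is now played by `l1_normalized` inside the quantizers. A quick numerical check of that claim:

import torch

# estimate_omega returned ||w||_1 / n; check that it matches the omega minimizing
# ||w - omega * sign(w)||^2, which is why it is used as the 1-bit scale.
w = torch.randn(1000)
omega_closed_form = w.norm(p=1).div(w.numel())

omegas = torch.linspace(0, 2, 201)
errors = torch.stack([(w - o * w.sign()).pow(2).sum() for o in omegas])
omega_grid = omegas[errors.argmin()]
print(float(omega_closed_form), float(omega_grid))  # agree up to the grid resolution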
@@ -314,16 +314,11 @@ class Trainer(object):
             is_quantized = (  # TODO: quantize `1 < quant_bits < 32` in optimizer
                 optim_ste and quant_method != "none" and quant_bits < 32
             )
-            if is_quantized and (
-                self.cfg.optimizer["_name"] not in {"adam", "madgrad"}
-                or quant_method == "lsq"
-                or quant_bits > 1
-            ):
-                raise NotImplementedError
-
             if is_quantized:
                 quant_param_names = []
                 params, non_quant_params = self.model.split_quant_params(quant_param_names)
             else:
                 self.cfg.optimizer.quant_bits = 32

             if self.is_fsdp and self.cfg.common.fp16:
                 # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
@@ -358,11 +353,14 @@
                     "NOTE: your device may support faster training with --fp16 or --amp"
                 )
         self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)

         # Add extra param_group that disables relevant regularization
         if is_quantized:
             self._optimizer._optimizer.add_param_group(
-                {"params": non_quant_params, "quant_bits": 32}
+                {
+                    "params": non_quant_params,
+                    "quant_bits": 32,
+                    "quant_method": "none",
+                }
             )

         if self.is_fsdp:
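The trainer builds the optimizer from the parameters selected by `split_quant_params` and then registers everything else in a second group whose settings disable quantization, so `QuantOptimizer.add_param_group` returns early for it. A rough sketch of that grouping, with plain SGD standing in for the quant-aware optimizer and a hypothetical dimension-based split instead of `model.split_quant_params`:

import torch

# The optimizer is built over the parameters to be quantized; the remaining parameters
# are added later as a group whose settings disable quantization.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))
quant_params = [p for p in model.parameters() if p.dim() == 2]       # projection weights
non_quant_params = [p for p in model.parameters() if p.dim() != 2]   # biases, norm gains

opt = torch.optim.SGD(
    [{"params": quant_params, "lr": 0.1, "quant_bits": 1, "quant_method": "least-sq"}],
    lr=0.1,
)
opt.add_param_group(
    {"params": non_quant_params, "quant_bits": 32, "quant_method": "none"}
)
print([len(g["params"]) for g in opt.param_groups])  # [1, 3]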