Optimizer support for least squares binarization

lisjin 2024-02-14 21:55:59 +00:00
parent da2777d7cb
commit 58a7aca97f
8 changed files with 149 additions and 74 deletions

View File

@@ -3,7 +3,7 @@ import torch.nn as nn
 from torch import Tensor
 from typing import Optional
-from .quant import (
+from fairseq.optim.quant import (
     Binarizer,
     get_qminmax,
     get_scale_init,
@@ -27,9 +27,14 @@ class AdaptedLinear(nn.Linear):
 class BinarizedLinear(AdaptedLinear):
     def __init__(self, weight_init: Tensor, bias_init: Optional[Tensor]):
         super().__init__(weight_init, bias_init)
+        self.register_buffer("v", torch.zeros(1))

     @property
     def weight(self):
-        return Binarizer.apply(self._weight)
+        return Binarizer.apply(self._weight, self.v, self.training)


 class LSQLinear(AdaptedLinear):
@@ -56,15 +61,12 @@ class QuantLS2Linear(AdaptedLinear):
     ):
         super().__init__(weight_init, bias_init)
-        self.register_buffer("v1", torch.zeros(1))
-        self.register_buffer("v2", torch.zeros(1))
+        self.register_buffer("vs", torch.zeros(2))
         self.stride = stride

     @property
     def weight(self):
-        return QuantLS2.apply(
-            self._weight, self.v1, self.v2, self.stride, self.training
-        )
+        return QuantLS2.apply(self._weight, self.vs, self.training, self.stride)

     def extra_repr(self):
         stride = self.stride
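
Note: the layers above recompute their quantized weight on every access of the `weight` property, caching the estimated scale(s) in registered buffers so evaluation reuses the values fitted during training. A standalone sketch of the scaled-sign binarization involved (illustrative only, not part of the diff; `w` stands in for the latent `self._weight`):

import torch

w = torch.randn(4, 8)            # stand-in for the latent full-precision weight
v = w.abs().mean()               # layer-wise scale, as in l1_normalized()
w_bin = w.sign() * v             # 1-bit weight used in the forward pass
print(v.item(), w_bin.unique())  # every entry is +v or -v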

View File

@@ -13,8 +13,9 @@ import torch
 import torch.distributed as dist
 import torch.optim
 from fairseq.dataclass import FairseqDataclass
-from fairseq.optim import FairseqOptimizer, quant_utils, register_optimizer
+from fairseq.optim import FairseqOptimizer, register_optimizer
 from fairseq.optim.fused_adam import get_fused_adam_class
+from fairseq.optim.quant_optimizer import QuantOptimizer
 from omegaconf import II, OmegaConf
@@ -56,8 +57,8 @@ class FairseqAdam(FairseqOptimizer):
         super().__init__(cfg)
         quant_bits = getattr(cfg, "quant_bits", 32)
         quant_method = getattr(cfg, "quant_method", "none")
-        if quant_bits > 2 and quant_method != "least-sq":
-            raise NotImplementedError(f"{quant_method=} not supported yet")
+        if 1 < quant_bits < 32 and quant_method == "parq":
+            raise NotImplementedError

         fused_adam_cls = get_fused_adam_class()
         use_fused_adam = (
@@ -121,7 +122,7 @@ class FairseqAdam(FairseqOptimizer):
             dist.all_reduce(value["exp_avg_sq"], op=dist.ReduceOp.SUM)


-class Adam(torch.optim.Optimizer):
+class Adam(QuantOptimizer):
     r"""Implements Adam algorithm.

     This implementation is modified from torch.optim.Adam based on:
@@ -178,11 +179,6 @@ class Adam(torch.optim.Optimizer):
     def supports_flat_params(self):
         return True

-    @staticmethod
-    def binarize_param(p_data_fp32) -> torch.Tensor:
-        omega = quant_utils.estimate_omega(p_data_fp32)
-        return quant_utils.scaled_sign_(p_data_fp32, omega)
-
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -206,9 +202,11 @@ class Adam(torch.optim.Optimizer):
                         "Adam does not support sparse gradients, please consider SparseAdam instead"
                     )
                 amsgrad = group.get("amsgrad", False)
-                apply_ste = (
-                    group["quant_method"] == "least-sq" and group["quant_bits"] < 32
-                )
+                apply_ste = group["quant_bits"] < 32 and group["quant_method"] in {
+                    "lsq",
+                    "least-sq",
+                }

                 p_data_fp32 = p.data
                 if p.data.dtype in {torch.float16, torch.bfloat16}:
@@ -265,8 +263,8 @@ class Adam(torch.optim.Optimizer):
                 p_buf.addcdiv_(exp_avg, denom, value=-step_size)

-                if apply_ste:  # load latent_p then binarize
-                    self.binarize_param(p_data_fp32.copy_(p_buf))
+                if apply_ste:
+                    self.quantize_param_(group, state, p_buf, p)

                 if p.data.dtype in {torch.float16, torch.bfloat16}:
                     p.data.copy_(p_data_fp32)
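
The STE path above keeps a full-precision latent copy in optimizer state, applies the Adam update to that buffer, and only then refreshes the low-bit parameter via `quantize_param_`. A simplified sketch of that ordering with plain tensors (illustrative only, not the fairseq code; the learning rate and gradient are made up):

import torch

latent = torch.randn(16)                      # full-precision copy (state["latent_p"])
p = latent.sign() * latent.abs().mean()       # quantized parameter seen by the model
grad = torch.randn_like(p)                    # gradient taken w.r.t. the quantized p
latent.add_(grad, alpha=-1e-2)                # optimizer step applied to the latent buffer
p.copy_(latent.sign() * latent.abs().mean())  # re-quantize into the live parameter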

View File

@@ -70,8 +70,9 @@ class FairseqMADGRAD(FairseqOptimizer):
         quant_bits = getattr(cfg, "quant_bits", 32)
         quant_method = getattr(cfg, "quant_method", "none")
-        if quant_bits > 2 and quant_method != "least-sq":
-            raise NotImplementedError(f"{quant_method=} not supported yet")
+        if 1 < quant_bits < 32 and quant_method == "parq":
+            raise NotImplementedError

         self._optimizer = MADGRAD(params, **self.optimizer_config)

     @property

View File

@@ -9,7 +9,7 @@ import torch.optim
 from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple

-from fairseq.optim import quant_utils
+from fairseq.optim.quant_optimizer import QuantOptimizer

 if TYPE_CHECKING:
     from torch.optim.optimizer import _params_t
@@ -17,7 +17,7 @@ else:
     _params_t = Any


-class MADGRAD(torch.optim.Optimizer):
+class MADGRAD(QuantOptimizer):
     """
     MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
     Optimization.
@@ -175,7 +175,7 @@ class MADGRAD(torch.optim.Optimizer):
             quant_method = group["quant_method"]
             apply_par = quant_bits == 1 and quant_method == "parq"
-            apply_ste = quant_bits == 1 and quant_method == "least-sq"
+            apply_ste = quant_bits < 32 and quant_method in {"lsq", "least-sq"}

             ck = 1 - momentum
             lamb = lr * math.pow(k + 1, 0.5)
@@ -253,11 +253,11 @@ class MADGRAD(torch.optim.Optimizer):
                     if apply_par:
                         rms.div_(state["lamb_sum"].add_(lamb)).clamp_(max=1)
-                        omega = quant_utils.estimate_omega(z)
+                        v = z.abs().mean()
                         z = torch.where(
-                            z.abs() < omega.mul(rms),
+                            z.abs() < v.mul(rms),
                             z.div(rms),
-                            torch.sign(z).mul_(omega),
+                            torch.sign(z).mul_(v),
                         )
                         state["z"].copy_(z)
@@ -267,20 +267,19 @@ class MADGRAD(torch.optim.Optimizer):
                         alpha=1 - ck,
                     )

-                    if decouple_decay and decay != 0:
-                        p_old = (
+                    if apply_ste or decouple_decay and decay != 0:
+                        p_buf = (
                             state["latent_p"] if apply_ste else p_data_fp32
                         ).clone()

                     if apply_ste:
                         state["latent_p"].copy_(z)
-                        omega = quant_utils.estimate_omega(z)
-                        quant_utils.scaled_sign_(z, omega)
+                        self.quantize_param_(group, state, state["latent_p"], z)
                         p_data_fp32.copy_(z)

                     if decouple_decay and decay != 0:
-                        p_data_fp32.add_(p_old, alpha=-lr * decay)
+                        p_data_fp32.add_(p_buf, alpha=-lr * decay)

                 if p.data.dtype in {torch.float16, torch.bfloat16}:
                     p.data.copy_(p_data_fp32)
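
In the `apply_par` branch above, the dual average `z` is either rescaled by `rms` or collapsed to a scaled sign, depending on whether it lies inside the band `|z| < v * rms`. A toy version of that projection (illustrative only; the real code also folds `lamb` into `lamb_sum` and writes the result back into `state["z"]`):

import torch

z = torch.randn(6)
rms = torch.full_like(z, 0.5)  # stand-in for the per-element rms term
v = z.abs().mean()
z_proj = torch.where(z.abs() < v * rms, z / rms, torch.sign(z) * v)
print(z, z_proj)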

View File

@@ -12,9 +12,18 @@ def get_qminmax(quant_bits) -> Tuple[int, int]:
     return qmin, qmax


-def get_quant_cls(quant_bits) -> Type[torch.autograd.Function]:
+def get_quant_cls(quant_method: str, quant_bits: int) -> Type[torch.autograd.Function]:
     assert quant_bits < 32
-    return Binarizer if quant_bits == 1 else LSQClampRound
+    if quant_method == "lsq":
+        quant_cls = LSQClampRound
+    elif quant_method == "least-sq":
+        if quant_bits == 1:
+            quant_cls = Binarizer
+        elif quant_bits == 2:
+            quant_cls = QuantLS2
+        else:
+            quant_cls = QuantLSGreedy
+    return quant_cls


 def get_scale_init(input, qmax):
@@ -25,21 +34,21 @@ def l1_normalized(input, dim=None, keepdim=False) -> torch.Tensor:
     return input.abs().mean(dim=dim, keepdim=keepdim)


-def scaled_sign(input) -> torch.Tensor:
-    omega = l1_normalized(input)
-    return input.sign().mul_(omega)
-
-
 class Binarizer(torch.autograd.Function):
     """1-bit binary quantization: https://arxiv.org/abs/1603.05279"""

     @staticmethod
-    def forward(ctx, input):
-        return scaled_sign(input)
+    def forward(ctx, input, v, training: bool = False):
+        if training:
+            v.copy_(l1_normalized(input))
+        input_sgn = input.sign().mul_(v)
+        return input_sgn

     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output if ctx.needs_input_grad[0] else None
+        grad_input = grad_output if ctx.needs_input_grad[0] else None
+        return grad_input, None, None


 class QuantLS2(torch.autograd.Function):
@@ -111,7 +120,7 @@ class QuantLS2(torch.autograd.Function):
         return v1

     @staticmethod
-    def forward(ctx, input, v1, v2, stride, training: bool = True):
+    def forward(ctx, input, vs, training: bool = False, stride: int = 5):
         if training:
             # Simulate striding along randomly permuted columns
             n_col = input.size(1) // stride
@@ -119,14 +128,14 @@ class QuantLS2(torch.autograd.Function):
             strided_input = input[:, idx].abs()

             v1s = QuantLS2.get_v1_cands(strided_input)
-            v1.copy_(QuantLS2.compute_v1(strided_input, v1s))
+            vs[0].copy_(QuantLS2.compute_v1(strided_input, v1s))

-            residual = QuantLS2.update_residual(input, v1)
-            v2.copy_(l1_normalized(residual))
+            residual = QuantLS2.update_residual(input, vs[0])
+            vs[1].copy_(l1_normalized(residual))

         # Use v1 and v2 to compute quantized input
-        input_sgn = input.sign().mul(v1)
-        input_sgn.add_((input - input_sgn).sign_().mul_(v2))
+        input_sgn = input.sign().mul(vs[0])
+        input_sgn.add_((input - input_sgn).sign_().mul_(vs[1]))
         return input_sgn

     @staticmethod
@@ -220,3 +229,24 @@ class LSQClampRound(torch.autograd.Function):
         grad_target = LSQClampRound.get_grad_target(grad_output, mid_mask)
         return grad_target, grad_scale, None, None
+
+
+class LSQNaive(torch.nn.Module):
+    @staticmethod
+    def grad_scale(x, scale):
+        x_scale = x.mul(scale)
+        return (x - x_scale).detach() + x_scale
+
+    @staticmethod
+    def round_pass(x):
+        x_round = x.round()
+        return (x_round - x).detach() + x
+
+    @staticmethod
+    def forward(target, scale, qmin, qmax):
+        grad_scale_factor = 1.0 / math.sqrt(qmax * target.numel())
+        s = LSQNaive.grad_scale(scale, grad_scale_factor)
+        quant_target = target.div(s).clamp_(qmin, qmax)
+        quant_target = LSQNaive.round_pass(quant_target).mul_(s)
+        return quant_target
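
`QuantLS2` approximates a weight tensor with two scaled sign terms, w ≈ vs[0]*sign(w) + vs[1]*sign(w - vs[0]*sign(w)), fitting the second scale to the residual left by the first. A rough standalone check of why the second term helps (illustrative only; it uses a plain mean-absolute v1 instead of the candidate search in `get_v1_cands`/`compute_v1`):

import torch

w = torch.randn(64, 64)
v1 = w.abs().mean()
b1 = w.sign() * v1   # first binary term
r = w - b1           # residual after the first term
v2 = r.abs().mean()  # least-squares scale for sign(r)
b2 = r.sign() * v2   # second binary term fit to the residual
print((w - b1).norm().item(), (w - (b1 + b2)).norm().item())  # error shrinks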

View File

@@ -0,0 +1,62 @@
+import torch
+
+from typing import Any, Dict
+
+from fairseq.optim.quant import (
+    get_qminmax,
+    get_quant_cls,
+    get_scale_init,
+)
+
+
+class QuantOptimizer(torch.optim.Optimizer):
+    def add_param_group(self, group: Dict[str, Any]) -> None:
+        super().add_param_group(group)
+        quant_bits = group["quant_bits"]
+        if quant_bits == 32:
+            return
+
+        quant_method = group["quant_method"]
+        if quant_method == "lsq":
+            # Set LSQ constants
+            group["qmin"], group["qmax"] = get_qminmax(quant_bits)
+
+        for p in group["params"]:
+            old_p = p.clone().detach()  # make full precision copy
+            group["quant_cls"] = get_quant_cls(quant_method, quant_bits)
+            if quant_method == "lsq":
+                # Initialize LSQ scale params
+                scale = torch.empty(1, device=old_p.device, dtype=old_p.dtype)
+                scale.copy_(get_scale_init(old_p, group["qmax"]))
+                p.data.copy_(
+                    group["quant_cls"].apply(old_p, scale, group["qmin"], group["qmax"])
+                )
+                self.state[p]["scale"] = scale
+            elif quant_method == "least-sq":
+                vs = torch.empty(quant_bits, device=old_p.device, dtype=old_p.dtype)
+                p.data.copy_(group["quant_cls"].apply(old_p, vs, True))
+                self.state[p]["vs"] = vs
+            self.state[p]["latent_p"] = old_p
+
+    @staticmethod
+    def quantize_param_(group, state, p_buf, p):
+        """Quantize `p_buf` based on `group["quant_method"]`, saving into `p`."""
+        quant_method = group["quant_method"]
+        quant_cls = group["quant_cls"]
+        if quant_method == "lsq":
+            quant_target = p_buf.div(state["scale"])
+            qmin, qmax = group["qmin"], group["qmax"]
+            args = (quant_target, qmin, qmax)
+            neg_mask, mid_mask, pos_mask = quant_cls.get_grad_masks(*args)
+            grad_scale = quant_cls.get_grad_scale(
+                *args, neg_mask, mid_mask, pos_mask, p.grad
+            )
+            state["scale"].sub_(grad_scale.sum(), alpha=group["lr"])
+            p.copy_(quant_cls.apply(p_buf, state["scale"], qmin, qmax))
+            p.grad.copy_(quant_cls.get_grad_target(p.grad, mid_mask))
+        elif quant_method == "least-sq":
+            p.copy_(quant_cls.apply(p_buf, state["vs"], True))
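
A minimal sketch of the bookkeeping pattern `add_param_group` encodes, shown on a toy SGD subclass rather than the fairseq classes (everything below is a hypothetical illustration): the group carries `quant_bits`, the live parameter is overwritten with its quantized form, and the full-precision copy is stashed in optimizer state for later straight-through updates.

import torch

class ToyQuantSGD(torch.optim.SGD):
    def add_param_group(self, group):
        super().add_param_group(group)
        if group.get("quant_bits", 32) == 32:
            return  # full-precision group: nothing to do
        for p in group["params"]:
            latent = p.clone().detach()                        # full-precision copy
            p.data.copy_(latent.sign() * latent.abs().mean())  # 1-bit scaled sign
            self.state[p]["latent_p"] = latent                 # kept for later STE steps

w = torch.nn.Parameter(torch.randn(4, 4))
opt = ToyQuantSGD([{"params": [w], "lr": 0.1, "quant_bits": 1}])
print(w.data.unique())  # two values: +/- mean(|w|)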

View File

@@ -1,15 +0,0 @@
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-def estimate_omega(latent_p):
-    """Compute layer-wise scaling factor estimated from latent weights.
-    XNOR-Net: https://arxiv.org/abs/1603.05279
-    """
-    return latent_p.norm(p=1).div(latent_p.numel())
-
-
-def scaled_sign_(latent_p, omega):
-    """In-place sign function scaled by layer-wise factor."""
-    return latent_p.sign_().mul_(omega)

View File

@@ -314,16 +314,11 @@ class Trainer(object):
         is_quantized = (  # TODO: quantize `1 < quant_bits < 32` in optimizer
             optim_ste and quant_method != "none" and quant_bits < 32
         )
-        if is_quantized and (
-            self.cfg.optimizer["_name"] not in {"adam", "madgrad"}
-            or quant_method == "lsq"
-            or quant_bits > 1
-        ):
-            raise NotImplementedError
         if is_quantized:
             quant_param_names = []
             params, non_quant_params = self.model.split_quant_params(quant_param_names)
         else:
             self.cfg.optimizer.quant_bits = 32

         if self.is_fsdp and self.cfg.common.fp16:
             # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
@@ -358,11 +353,14 @@ class Trainer(object):
                 "NOTE: your device may support faster training with --fp16 or --amp"
             )

         self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)
         # Add extra param_group that disables relevant regularization
         if is_quantized:
             self._optimizer._optimizer.add_param_group(
-                {"params": non_quant_params, "quant_bits": 32}
+                {
+                    "params": non_quant_params,
+                    "quant_bits": 32,
+                    "quant_method": "none",
+                }
             )

         if self.is_fsdp: