Add regularization for multihead attention module and ffn module

Summary: [Fairseq] Add regularization for multihead attention module and ffn module

Reviewed By: dianaml0

Differential Revision: D32441521

fbshipit-source-id: c648c1f8ec1a3310ba90c4952cdd40a21b959d26
Author: Liang Tan, 2021-12-30 02:01:11 -08:00 (committed by Facebook GitHub Bot)
Parent: 7fddb9d960
Commit: 2762a1cfef
3 changed files with 88 additions and 9 deletions


@@ -8,7 +8,7 @@ from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@@ -54,18 +54,32 @@ class SentencePredictionCriterion(FairseqCriterion):
if not self.regression_target:
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = F.nll_loss(lprobs, targets, reduction="sum")
task_loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
logits = logits.view(-1).float()
targets = targets.float()
loss = F.mse_loss(logits, targets, reduction="sum")
task_loss = F.mse_loss(logits, targets, reduction="sum")
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
logging_output = {}
loss = task_loss
# mha & ffn regularization update
if hasattr(model.args, "mha_reg_scale_factor") and model.args.mha_reg_scale_factor != 0.0:
mha_reg_loss = model._get_adaptive_head_loss()
loss += mha_reg_loss
logging_output.update({"mha_reg_loss": mha_reg_loss})
if hasattr(model.args, "ffn_reg_scale_factor") and model.args.ffn_reg_scale_factor != 0.0:
ffn_reg_loss = model._get_adaptive_ffn_loss()
loss += ffn_reg_loss
logging_output.update({"ffn_reg_loss": ffn_reg_loss})
logging_output.update(
{
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
)
if not self.regression_target:
preds = logits.argmax(dim=1)
logging_output["ncorrect"] = (preds == targets).sum()
@@ -79,10 +93,20 @@ class SentencePredictionCriterion(FairseqCriterion):
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
mha_reg_loss_sum = sum(log.get("mha_reg_loss", 0) for log in logging_outputs)
ffn_reg_loss_sum = sum(log.get("ffn_reg_loss", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if mha_reg_loss_sum:
    metrics.log_scalar(
        "mha_reg_loss",
        mha_reg_loss_sum / sample_size / math.log(2),
        sample_size,
        round=3,
    )
if ffn_reg_loss_sum:
    metrics.log_scalar(
        "ffn_reg_loss",
        ffn_reg_loss_sum / sample_size / math.log(2),
        sample_size,
        round=3,
    )
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3

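Editorial aside (not part of the commit): the division by math.log(2) above follows fairseq's convention of logging losses in bits rather than nats, so the summed loss is divided by the sample size and by ln 2. A minimal sketch with made-up numbers:

import math

# Hypothetical values, not taken from this diff: a summed NLL of 44.36 nats
# over 8 sentences is 44.36 / 8 / ln(2), roughly 8.0 bits per sentence.
loss_sum, sample_size = 44.36, 8
bits_per_sentence = loss_sum / sample_size / math.log(2)
print(round(bits_per_sentence, 3))  # ~= 8.0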

@@ -185,6 +185,23 @@ class RobertaModel(FairseqEncoderModel):
"--offload-activations are passed."
),
)
# args for AdaPruning
# In short, this adds regularization for the multihead attention module
# and the feed forward network (FFN) module.
# For more details, please refer to the paper https://openreview.net/forum?id=_CMSV7FTzGI
parser.add_argument(
"--mha-reg-scale-factor",
type=float,
metavar="D",
default=0.0,
help="scaling factor for regularization term in adptive pruning, recommendation is 0.000375",
)
parser.add_argument(
"--ffn-reg-scale-factor",
type=float,
metavar="D",
default=0.0,
help="scaling factor for regularization term in adptive pruning, recommendation is 0.000375",
)
@classmethod
def build_model(cls, args, task):
@@ -227,6 +244,29 @@ class RobertaModel(FairseqEncoderModel):
x = self.classification_heads[classification_head_name](x)
return x, extra
def _get_adaptive_head_loss(self):
    norm_loss = 0
    scaling = float(self.args.mha_reg_scale_factor)
    for layer in self.encoder.sentence_encoder.layers:
        norm_loss_layer = 0
        for i in range(layer.self_attn.num_heads):
            start_idx = i * layer.self_attn.head_dim
            end_idx = (i + 1) * layer.self_attn.head_dim
            norm_loss_layer += scaling * (
                torch.sum(torch.abs(layer.self_attn.q_proj.weight[start_idx:end_idx, :]))
                + torch.sum(torch.abs(layer.self_attn.q_proj.bias[start_idx:end_idx]))
            )
            norm_loss_layer += scaling * (
                torch.sum(torch.abs(layer.self_attn.k_proj.weight[start_idx:end_idx, :]))
                + torch.sum(torch.abs(layer.self_attn.k_proj.bias[start_idx:end_idx]))
            )
            norm_loss_layer += scaling * (
                torch.sum(torch.abs(layer.self_attn.v_proj.weight[start_idx:end_idx, :]))
                + torch.sum(torch.abs(layer.self_attn.v_proj.bias[start_idx:end_idx]))
            )
        norm_loss += norm_loss_layer
    return norm_loss
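Editorial aside (not part of the commit): a minimal, self-contained sketch of the same per-head group-L1 idea on a single toy projection; the dimensions and names below are made up for illustration.

import torch
import torch.nn as nn

# Rows [i * head_dim : (i + 1) * head_dim] of each q/k/v projection produce head i's
# queries/keys/values, so penalizing them as a group drives whole heads toward zero,
# which is what makes them prunable.
embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads
q_proj = nn.Linear(embed_dim, embed_dim)
scaling = 0.000375  # scale recommended in the AdaPruning paper

penalty = 0.0
for i in range(num_heads):
    start, end = i * head_dim, (i + 1) * head_dim
    penalty = penalty + scaling * (
        q_proj.weight[start:end].abs().sum() + q_proj.bias[start:end].abs().sum()
    )
print(penalty)  # scalar tensor; gradients flow back into q_proj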
def _get_adaptive_ffn_loss(self):
    ffn_scale_factor = float(self.args.ffn_reg_scale_factor)
    filter_loss = 0
    for layer in self.encoder.sentence_encoder.layers:
        filter_loss += torch.sum(
            torch.abs(layer.fc1.weight * ffn_scale_factor)
        ) + torch.sum(torch.abs(layer.fc2.weight * ffn_scale_factor))
        filter_loss += torch.sum(
            torch.abs(layer.fc1.bias * ffn_scale_factor)
        ) + torch.sum(torch.abs(layer.fc2.bias * ffn_scale_factor))
    return filter_loss
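Editorial aside (not part of the commit): because the scale factor is non-negative, multiplying inside the absolute value is the same as scaling the L1 norm, so the FFN penalty equals ffn_reg_scale_factor times the L1 norm of the fc1/fc2 weights and biases, summed over layers. A quick sanity check of that identity:

import torch

# Illustration only: sum(|w * s|) == s * sum(|w|) for a non-negative scale s.
w = torch.randn(16, 16)
s = 0.000375
assert torch.allclose(torch.sum(torch.abs(w * s)), s * torch.sum(torch.abs(w)))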
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
logits = net_output[0].float()

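Editorial aside (not part of the commit): both new flags default to 0.0, which leaves training unchanged; the criterion only adds a penalty when the corresponding scale factor is non-zero. A hedged sketch of that guard, using a stand-in args namespace rather than a real fairseq model:

from argparse import Namespace

# Hypothetical stand-in for model.args after argument parsing; only the attribute
# names match the flags added above, everything else is made up for illustration.
args = Namespace(mha_reg_scale_factor=0.000375, ffn_reg_scale_factor=0.0)

mha_active = hasattr(args, "mha_reg_scale_factor") and args.mha_reg_scale_factor != 0.0
ffn_active = hasattr(args, "ffn_reg_scale_factor") and args.ffn_reg_scale_factor != 0.0
print(mha_active, ffn_active)  # True False -> only the MHA penalty would be added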

@@ -305,6 +305,21 @@ class RobertaTest(unittest.TestCase):
# Incremental vs non-incremental
self.assertTensorEqual(ro_dec_inc[i][:, 0], ro_dec[:, i])
@cpu_gpu
def test_regularize_for_adaprune_in_roberta(self, device: str):
_, model = get_toy_model(
device=device,
architecture="roberta_base",
mha_reg_scale_factor=0.000375,
ffn_reg_scale_factor=0.000375,
)
sample = mk_sample("en", device, batch_size=1)
task_loss, _ = model.forward(**sample["net_input"])
head_loss = model._get_adaptive_head_loss()
ffn_loss = model._get_adaptive_ffn_loss()
loss = task_loss.sum() + head_loss + ffn_loss
loss.backward()
def params(model, name):
if "." not in name: