Filter padding properly in LabelSmoothedCrossEntropyCriterion (#229)

Myle Ott 2018-03-04 15:44:05 -05:00 committed by Sergey Edunov
parent 5f29d1236c
commit e73fddf453
3 changed files with 108 additions and 66 deletions

View File

@@ -7,7 +7,6 @@
import math
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from fairseq import utils
@@ -15,41 +14,6 @@ from fairseq import utils
from . import FairseqCriterion, register_criterion
class LabelSmoothedNLLLoss(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, target, eps, padding_idx, weights, reduce=True):
        grad_input = input.new(input.size()).zero_()
        target = target.view(target.size(0), 1)
        grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1)
        norm = grad_input.size(-1)
        if weights is not None:
            if isinstance(grad_input, Variable) and not isinstance(weights, Variable):
                weights = Variable(weights, requires_grad=False)
            norm = weights.sum()
            grad_input.mul(weights.view(1, weights.size(0)).expand_as(grad_input))
        if padding_idx is not None:
            norm -= 1 if weights is None else weights[padding_idx]
            grad_input.select(grad_input.dim() - 1, padding_idx).fill_(0)
        grad_input = grad_input.add(-eps / norm)
        ctx.grad_input = grad_input
        if reduce:
            return grad_input.view(-1).dot(input.view(-1))
        else:
            return grad_input * input

    @staticmethod
    def backward(ctx, grad):
        grad_input = ctx.grad_input
        if not isinstance(grad_input, torch.autograd.Variable):
            grad_input = utils.volatile_variable(grad_input)
        return grad_input * grad, None, None, None, None, None
@register_criterion('label_smoothed_cross_entropy')
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
@@ -73,10 +37,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
"""
net_output = model(**sample['net_input'])
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = sample['target'].view(-1)
loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, None, reduce)
nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
        target = sample['target'].view(-1, 1)  # flatten targets to align with the flattened lprobs
        non_pad_mask = target.ne(self.padding_idx)
        nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
        if reduce:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()
        eps_i = self.eps / lprobs.size(-1)
        loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
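
The heart of the change above is that padded target positions are masked out before either loss term is accumulated. Below is a minimal standalone sketch of the same computation (my own illustration with made-up shapes and a hypothetical padding_idx, not the fairseq module itself):

import torch
import torch.nn.functional as F

padding_idx = 1   # assumed pad index, as in fairseq dictionaries
eps = 0.1         # label smoothing weight

# pretend decoder output: 4 flattened target positions, vocab size 7
lprobs = F.log_softmax(torch.randn(4, 7), dim=-1)
target = torch.tensor([[4], [5], [2], [padding_idx]])  # last position is padding

non_pad_mask = target.ne(padding_idx)
nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]    # true-label term, pads dropped
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]    # uniform-over-vocab term, pads dropped
eps_i = eps / lprobs.size(-1)
loss = (1. - eps) * nll_loss.sum() + eps_i * smooth_loss.sum()

Because non_pad_mask is applied before the sums, padded positions contribute nothing to either term, no matter how much padding a given batch contains.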

View File

@@ -4,31 +4,98 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import argparse
import copy
import unittest
import torch
import unittest
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedNLLLoss
from torch.autograd import Variable, gradcheck
from torch.autograd import Variable
from fairseq import utils
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
torch.set_default_tensor_type('torch.DoubleTensor')
import tests.utils as test_utils
class TestLabelSmoothing(unittest.TestCase):

    def test_label_smoothing(self):
        input = Variable(torch.randn(3, 5), requires_grad=True)
        idx = torch.rand(3) * 4
        target = Variable(idx.long())
        criterion = LabelSmoothedNLLLoss()
        self.assertTrue(gradcheck(
            lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
        ))
        weights = torch.ones(5)
        weights[2] = 0
        self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, weights), (input, target)))
        self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, None), (input, target)))

    def setUp(self):
        # build dictionary
        self.d = test_utils.dummy_dictionary(3)
        vocab = len(self.d)
        self.assertEqual(vocab, 4 + 3)  # 4 special + 3 tokens
        self.assertEqual(self.d.pad(), 1)
        self.assertEqual(self.d.eos(), 2)
        self.assertEqual(self.d.unk(), 3)
        pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6

        # build dataset
        self.data = [
            # the first batch item has padding
            {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, eos])},
            {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, w1, eos])},
        ]
        self.sample = next(test_utils.dummy_dataloader(self.data))

        # build model
        self.args = argparse.Namespace()
        self.args.sentence_avg = False
        self.args.probs = torch.FloatTensor([
            # pad eos unk w1 w2 w3
            [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],
            [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
            [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
        ]).unsqueeze(0).expand(2, 3, 7)  # add batch dimension
        self.model = test_utils.TestModel.build_model(self.args, self.d, self.d)

    def test_nll_loss(self):
        self.args.label_smoothing = 0.1
        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
        nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
        self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
        self.assertLess(abs(nll_loss - smooth_logging_output['nll_loss']), 1e-6)

    def test_padding(self):
        self.args.label_smoothing = 0.1
        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
        loss, _, logging_output = crit(self.model, self.sample)

        def get_one_no_padding(idx):
            # create a new sample with just a single batch item so that there's
            # no padding
            sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
            args1 = copy.copy(self.args)
            args1.probs = args1.probs[idx, :, :].unsqueeze(0)
            model1 = test_utils.TestModel.build_model(args1, self.d, self.d)
            loss1, _, _ = crit(model1, sample1)
            return loss1

        loss1 = get_one_no_padding(0)
        loss2 = get_one_no_padding(1)
        self.assertAlmostEqual(loss, loss1 + loss2)

    def test_reduction(self):
        self.args.label_smoothing = 0.1
        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
        loss, _, logging_output = crit(self.model, self.sample, reduce=True)
        unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
        self.assertAlmostEqual(loss, unreduced_loss.sum())

    def test_zero_eps(self):
        self.args.label_smoothing = 0.0
        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
        nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
        self.assertAlmostEqual(nll_loss, smooth_loss)

    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-6)
if __name__ == '__main__':
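
For intuition on test_zero_eps above: with label_smoothing = 0.0 the interpolation (1 - eps) * nll_loss + eps_i * smooth_loss collapses to the plain NLL term, so the smoothed criterion must agree with CrossEntropyCriterion. A quick numerical sketch of that identity (my own, with made-up log-probabilities and independent of the fairseq classes):

import torch
import torch.nn.functional as F

lprobs = F.log_softmax(torch.randn(5, 7), dim=-1)   # 5 positions, vocab of 7
target = torch.randint(0, 7, (5, 1))
eps = 0.0

nll = -lprobs.gather(dim=-1, index=target).sum()
smooth = -lprobs.sum()
loss = (1. - eps) * nll + (eps / lprobs.size(-1)) * smooth

# with eps = 0 this is exactly the summed negative log-likelihood
assert torch.allclose(loss, F.nll_loss(lprobs, target.squeeze(-1), reduction='sum'))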

View File

@@ -92,7 +92,7 @@ class TestEncoder(FairseqEncoder):
class TestIncrementalDecoder(FairseqIncrementalDecoder):

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        assert hasattr(args, 'beam_probs')
        assert hasattr(args, 'beam_probs') or hasattr(args, 'probs')
        args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
        self.args = args
@@ -116,14 +116,19 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
        steps = list(range(tgt_len))

        # define output in terms of raw probs
        probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
        for i, step in enumerate(steps):
            # args.beam_probs gives the probability for every vocab element,
            # starting with eos, then unknown, and then the rest of the vocab
            if step < len(self.args.beam_probs):
                probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
            else:
                probs[:, i, self.dictionary.eos()] = 1.0
        if hasattr(self.args, 'probs'):
            assert self.args.probs.dim() == 3, \
                'expected probs to have size bsz*steps*vocab'
            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
        else:
            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
            for i, step in enumerate(steps):
                # args.beam_probs gives the probability for every vocab element,
                # starting with eos, then unknown, and then the rest of the vocab
                if step < len(self.args.beam_probs):
                    probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
                else:
                    probs[:, i, self.dictionary.eos()] = 1.0

        # random attention
        attn = torch.rand(bbsz, src_len, tgt_len)
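
The new args.probs branch above hands the test decoder a fixed bsz x steps x vocab probability table and simply slices out the rows for the requested decoding steps. A tiny sketch of that index_select (my own illustration with made-up sizes):

import torch

probs = torch.rand(2, 3, 4)                              # batch of 2, 3 total steps, vocab of 4
steps = [0, 2]                                           # e.g. only these steps are being decoded
picked = probs.index_select(1, torch.LongTensor(steps))  # -> shape (2, 2, 4)
print(picked.shape)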