Rename LabelSmoothedCrossEntropy to LabelSmoothedNLLLoss

This commit is contained in:
Myle Ott 2017-11-07 17:12:23 -05:00
parent b1dfd39eb2
commit e1f49695ee
2 changed files with 4 additions and 4 deletions

View File

@@ -14,7 +14,7 @@ import torch.nn.functional as F
from .fairseq_criterion import FairseqCriterion
class LabelSmoothedCrossEntropy(torch.autograd.Function):
class LabelSmoothedNLLLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, input, target, eps, padding_idx, weights):
@@ -59,7 +59,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
net_output = model(**sample['net_input'])
input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
target = sample['target'].view(-1)
loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': loss.data[0],

View File

@@ -8,7 +8,7 @@
import torch
import unittest
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropy
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedNLLLoss
from torch.autograd import Variable, gradcheck
@@ -21,7 +21,7 @@ class TestLabelSmoothing(unittest.TestCase):
input = Variable(torch.randn(3, 5), requires_grad=True)
idx = torch.rand(3) * 4
target = Variable(idx.long())
criterion = LabelSmoothedCrossEntropy()
criterion = LabelSmoothedNLLLoss()
self.assertTrue(gradcheck(
lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
))