The composite criterion should still use the legacy criterion, since it will break with a subsequent diff

Summary: see title

Reviewed By: myleott

Differential Revision: D24393903

fbshipit-source-id: 4b972b8150c7228fb32977675c6c60b13d5194d0
This commit is contained in:
Alexei Baevski 2020-10-19 20:15:47 -07:00 committed by Facebook GitHub Bot
parent de5c2cb35a
commit c76cb6dfb9

View File

@@ -4,18 +4,18 @@
# LICENSE file in the root directory of this source tree.
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from torch import nn
@register_criterion("composite_loss")
class CompositeLoss(FairseqCriterion):
class CompositeLoss(LegacyFairseqCriterion):
"""This is a composite loss that, given a list of model outputs and a list of targets,
computes an average of losses for each output-target pair"""
def __init__(self, task, underlying_criterion):
super().__init__(task)
self.underlying_criterion = underlying_criterion
def __init__(self, args, task):
super().__init__(args, task)
self.underlying_criterion = args.underlying_criterion
@staticmethod
def add_args(parser):
@@ -60,9 +60,9 @@ class CompositeLoss(FairseqCriterion):
def decoder(self):
return self.model.decoder
class _CompositeLoss(FairseqCriterion):
def __init__(self, task, underlying_criterion):
super().__init__(task)
class _CompositeLoss(LegacyFairseqCriterion):
def __init__(self, args, task, underlying_criterion):
super().__init__(args, task)
self.underlying_criterion = underlying_criterion
def forward(self, model, sample, reduce=True):
@@ -97,4 +97,4 @@ class CompositeLoss(FairseqCriterion):
def reduce_metrics(logging_outputs) -> None:
underlying_criterion.__class__.reduce_metrics(logging_outputs)
return _CompositeLoss(task, underlying_criterion)
return _CompositeLoss(args, task, underlying_criterion)