From c76cb6dfb93b531369a0a7593227b31c3b99c0a3 Mon Sep 17 00:00:00 2001 From: Alexei Baevski Date: Mon, 19 Oct 2020 20:15:47 -0700 Subject: [PATCH] composite criterion should still use legacy criterion as it will break with subsequent diff Summary: see title Reviewed By: myleott Differential Revision: D24393903 fbshipit-source-id: 4b972b8150c7228fb32977675c6c60b13d5194d0 --- fairseq/criterions/composite_loss.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/fairseq/criterions/composite_loss.py b/fairseq/criterions/composite_loss.py index 65341c2d3..98e835fa6 100644 --- a/fairseq/criterions/composite_loss.py +++ b/fairseq/criterions/composite_loss.py @@ -4,18 +4,18 @@ # LICENSE file in the root directory of this source tree. from fairseq import utils -from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions import LegacyFairseqCriterion, register_criterion from torch import nn @register_criterion("composite_loss") -class CompositeLoss(FairseqCriterion): +class CompositeLoss(LegacyFairseqCriterion): """This is a composite loss that, given a list of model outputs and a list of targets, computes an average of losses for each output-target pair""" - def __init__(self, task, underlying_criterion): - super().__init__(task) - self.underlying_criterion = underlying_criterion + def __init__(self, args, task): + super().__init__(args, task) + self.underlying_criterion = args.underlying_criterion @staticmethod def add_args(parser): @@ -60,9 +60,9 @@ class CompositeLoss(FairseqCriterion): def decoder(self): return self.model.decoder - class _CompositeLoss(FairseqCriterion): - def __init__(self, task, underlying_criterion): - super().__init__(task) + class _CompositeLoss(LegacyFairseqCriterion): + def __init__(self, args, task, underlying_criterion): + super().__init__(args, task) self.underlying_criterion = underlying_criterion def forward(self, model, sample, reduce=True): @@ -97,4 +97,4 @@ class CompositeLoss(FairseqCriterion): def reduce_metrics(logging_outputs) -> None: underlying_criterion.__class__.reduce_metrics(logging_outputs) - return _CompositeLoss(task, underlying_criterion) + return _CompositeLoss(args, task, underlying_criterion)