diff --git a/moses/FF/VW/VW.h b/moses/FF/VW/VW.h
index 8b7330440..1033fabd7 100644
--- a/moses/FF/VW/VW.h
+++ b/moses/FF/VW/VW.h
@@ -59,6 +59,37 @@ private:
   int m_min, m_max;
 };
 
+/**
+ * Calculation of training loss.
+ */
+class TrainingLoss
+{
+public:
+  virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const = 0;
+};
+
+/**
+ * BLEU2+1 training loss (placeholder: currently identical to the basic 1/0 loss).
+ */
+class TrainingLossBasic : public TrainingLoss
+{
+public:
+  virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
+    return isCorrect ? 0.0 : 1.0;
+  }
+};
+
+/**
+ * BLEU2+1 training loss (placeholder: currently identical to the basic 1/0 loss).
+ */
+class TrainingLossBLEU : public TrainingLoss
+{
+public:
+  virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
+    return isCorrect ? 0.0 : 1.0;
+  }
+};
+
 /**
  * VW thread-specific data about target sentence.
  */
@@ -96,6 +127,7 @@ struct VWTargetSentence {
 };
 
 typedef ThreadLocalByFeatureStorage<Discriminative::Classifier, Discriminative::ClassifierFactory &> TLSClassifier;
+
 typedef ThreadLocalByFeatureStorage<VWTargetSentence> TLSTargetSentence;
 
 class VW : public StatelessFeatureFunction, public TLSTargetSentence
@@ -116,6 +148,10 @@ public:
       VERBOSE(1, "VW :: No loss function specified, assuming logistic loss.\n");
       m_normalizer = (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer();
     }
+
+    if (! m_trainingLoss) {
+      m_trainingLoss = (TrainingLoss *) new TrainingLossBasic();
+    }
   }
 
   virtual ~VW() {
@@ -150,76 +186,107 @@ public:
 
     VERBOSE(2, "VW :: Evaluating translation options\n");
 
-    const std::vector<VWFeatureBase*>& sourceFeatures = VWFeatureBase::GetSourceFeatures(GetScoreProducerDescription());
+    // which feature functions do we use (on the source and target side)
+    const std::vector<VWFeatureBase*>& sourceFeatures =
+      VWFeatureBase::GetSourceFeatures(GetScoreProducerDescription());
+
+    const std::vector<VWFeatureBase*>& targetFeatures =
+      VWFeatureBase::GetTargetFeatures(GetScoreProducerDescription());
 
     const WordsRange &sourceRange = translationOptionList.Get(0)->GetSourceWordsRange();
     const InputPath &inputPath = translationOptionList.Get(0)->GetInputPath();
 
-    // optionally update translation options using leave-one-out
-    std::vector<bool> keep = (m_leaveOneOut.size() > 0)
-                             ? LeaveOneOut(translationOptionList)
-                             : std::vector<bool>(translationOptionList.size(), true);
-
-    std::vector<float> losses(translationOptionList.size());
-    std::vector<float>::iterator iterLoss;
-    TranslationOptionList::const_iterator iterTransOpt;
-    std::vector<bool>::const_iterator iterKeep;
-
     if (m_train) {
-      // check which translation options are correct in advance
-      bool seenCorrect = false;
-      for(iterTransOpt = translationOptionList.begin(), iterLoss = losses.begin(), iterKeep = keep.begin() ;
-          iterTransOpt != translationOptionList.end() ; ++iterTransOpt, ++iterLoss, ++iterKeep) {
-        bool isCorrect = IsCorrectTranslationOption(**iterTransOpt);
-        *iterLoss = isCorrect ? 0.0 : 1.0;
-        if (isCorrect && *iterKeep) seenCorrect = true;
+      //
+      // extract features for training the classifier (only call this when using vwtrainer, not in Moses!)
+      //
+
+      // find which topts are correct
+      std::vector<bool> correct(translationOptionList.size());
+      for (size_t i = 0; i < translationOptionList.size(); i++)
+        correct[i] = IsCorrectTranslationOption(* translationOptionList.Get(i));
+
+      // optionally update translation options using leave-one-out
+      std::vector<bool> keep = (m_leaveOneOut.size() > 0)
+                               ? LeaveOneOut(translationOptionList, correct)
+                               : std::vector<bool>(translationOptionList.size(), true);
+
+      // check whether we (still) have some correct translation
+      int firstCorrect = -1;
+      for (size_t i = 0; i < translationOptionList.size(); i++) {
+        if (keep[i] && correct[i]) {
+          firstCorrect = i;
+          break;
+        }
       }
 
       // do not train if there are no positive examples
-      if (! seenCorrect) {
+      if (firstCorrect == -1) {
         VERBOSE(2, "VW :: skipping topt collection, no correct translation for span\n");
         return;
       }
-    }
 
-    for(size_t i = 0; i < sourceFeatures.size(); ++i)
-      (*sourceFeatures[i])(input, inputPath, sourceRange, classifier);
+      // the first correct topt can be used by some loss functions
+      const TargetPhrase &correctPhrase = translationOptionList.Get(firstCorrect)->GetTargetPhrase();
 
-    const std::vector<VWFeatureBase*>& targetFeatures = VWFeatureBase::GetTargetFeatures(GetScoreProducerDescription());
+      // extract source side features
+      for(size_t i = 0; i < sourceFeatures.size(); ++i)
+        (*sourceFeatures[i])(input, inputPath, sourceRange, classifier);
 
-    for(iterTransOpt = translationOptionList.begin(), iterLoss = losses.begin(), iterKeep = keep.begin() ;
-        iterTransOpt != translationOptionList.end() ; ++iterTransOpt, ++iterLoss) {
+      // go over topts, extract target side features and train the classifier
+      for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
 
-      if (! *iterKeep)
-        continue;
+        // this topt was discarded by leaving one out
+        if (! keep[toptIdx])
+          continue;
 
-      const TargetPhrase &targetPhrase = (*iterTransOpt)->GetTargetPhrase();
-      for(size_t i = 0; i < targetFeatures.size(); ++i)
-        (*targetFeatures[i])(input, inputPath, targetPhrase, classifier);
+        // extract target-side features for each topt
+        const TargetPhrase &targetPhrase = translationOptionList.Get(toptIdx)->GetTargetPhrase();
+        for(size_t i = 0; i < targetFeatures.size(); ++i)
+          (*targetFeatures[i])(input, inputPath, targetPhrase, classifier);
 
-      if (! m_train) {
-        *iterLoss = classifier.Predict(MakeTargetLabel(targetPhrase));
-      } else {
-        classifier.Train(MakeTargetLabel(targetPhrase), *iterLoss);
+        float loss = (*m_trainingLoss)(targetPhrase, correctPhrase, correct[toptIdx]);
+
+        // train classifier on current example
+        classifier.Train(MakeTargetLabel(targetPhrase), loss);
       }
-    }
+    } else {
+      //
+      // predict using a trained classifier, use this in decoding (=at test time)
+      //
+
+      std::vector<float> losses(translationOptionList.size());
 
-    (*m_normalizer)(losses);
+      // extract source side features
+      for(size_t i = 0; i < sourceFeatures.size(); ++i)
+        (*sourceFeatures[i])(input, inputPath, sourceRange, classifier);
 
-    for(iterTransOpt = translationOptionList.begin(), iterLoss = losses.begin(), iterKeep = keep.begin() ;
-        iterTransOpt != translationOptionList.end() ; ++iterTransOpt, ++iterLoss) {
-      if (! *iterKeep)
-        continue;
+      for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
+        const TranslationOption *topt = translationOptionList.Get(toptIdx);
+        const TargetPhrase &targetPhrase = topt->GetTargetPhrase();
 
-      TranslationOption &transOpt = **iterTransOpt;
+        // extract target-side features for each topt
+        for(size_t i = 0; i < targetFeatures.size(); ++i)
+          (*targetFeatures[i])(input, inputPath, targetPhrase, classifier);
 
-      std::vector<float> newScores(m_numScoreComponents);
-      newScores[0] = FloorScore(TransformScore(*iterLoss));
+        // get classifier score
+        losses[toptIdx] = classifier.Predict(MakeTargetLabel(targetPhrase));
+      }
 
-      ScoreComponentCollection &scoreBreakDown = transOpt.GetScoreBreakdown();
-      scoreBreakDown.PlusEquals(this, newScores);
+      // normalize classifier scores to get a probability distribution
+      (*m_normalizer)(losses);
 
-      transOpt.UpdateScore();
+      // update scores of topts
+      for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
+        TranslationOption *topt = *(translationOptionList.begin() + toptIdx);
+        std::vector<float> newScores(m_numScoreComponents);
+        newScores[0] = FloorScore(TransformScore(losses[toptIdx]));
+
+        ScoreComponentCollection &scoreBreakDown = topt->GetScoreBreakdown();
+        scoreBreakDown.PlusEquals(this, newScores);
+
+        topt->UpdateScore();
+      }
     }
   }
 
@@ -231,7 +298,6 @@ public:
                         ScoreComponentCollection* accumulator) const {
   }
 
-
   void SetParameter(const std::string& key, const std::string& value) {
     if (key == "train") {
       m_train = Scan<bool>(value);
@@ -241,10 +307,25 @@ public:
     } else if (key == "path") {
      m_modelPath = value;
    } else if (key == "vw-options") {
      m_vwOptions = value;
    } else if (key == "leave-one-out-from") {
      m_leaveOneOut = value;
+    } else if (key == "training-loss") {
+      // which type of loss to use for training
+      if (value == "basic") {
+        m_trainingLoss = (TrainingLoss *) new TrainingLossBasic();
+      } else if (value == "bleu") {
+        m_trainingLoss = (TrainingLoss *) new TrainingLossBLEU();
+      } else {
+        UTIL_THROW2("Unknown training loss type: " << value);
+      }
     } else if (key == "loss") {
-      m_normalizer = value == "logistic"
-                     ? (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer()
-                     : (Discriminative::Normalizer *) new Discriminative::SquaredLossNormalizer();
+      // which normalizer to use; theoretically this depends on the loss function used for
+      // training the classifier (squared/logistic/hinge/...), hence the name "loss"
+      if (value == "logistic") {
+        m_normalizer = (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer();
+      } else if (value == "squared") {
+        m_normalizer = (Discriminative::Normalizer *) new Discriminative::SquaredLossNormalizer();
+      } else {
+        UTIL_THROW2("Unknown loss type: " << value);
+      }
     } else {
       StatelessFeatureFunction::SetParameter(key, value);
     }
 
@@ -255,12 +336,13 @@ public:
     if (! m_train) return;
 
-    UTIL_THROW_IF2(source.GetType() != TabbedSentenceInput, "This feature function requires the TabbedSentence input type");
+    UTIL_THROW_IF2(source.GetType() != TabbedSentenceInput,
+                   "This feature function requires the TabbedSentence input type");
 
     const TabbedSentence& tabbedSentence = static_cast<const TabbedSentence&>(source);
-    UTIL_THROW_IF2(tabbedSentence.GetColumns().size() < 2, "TabbedSentence must contain target<tab>alignment");
-
-
+    UTIL_THROW_IF2(tabbedSentence.GetColumns().size() < 2,
+                   "TabbedSentence must contain target<tab>alignment");
+
     // target sentence represented as a phrase
     Phrase *target = new Phrase();
     target->CreateFromString(
@@ -279,11 +361,6 @@ public:
     targetSent.m_sentence = target;
     targetSent.m_alignment = alignment;
 
-    //std::cerr << static_cast<const Sentence&>(tabbedSentence) << std::endl;
-    //std::cerr << *target << std::endl;
-    //std::cerr << *alignment << std::endl;
-
-
     // pre-compute max- and min- aligned points for faster translation option checking
     targetSent.SetConstraints(source.GetSize());
   }
@@ -328,7 +405,9 @@ private:
       targetStart2 = i;
 
    int targetEnd2 = targetEnd;
-    for(int i = targetEnd2; i < targetSentence.m_sentence->GetSize() && !targetSentence.m_targetConstraints[i].IsSet(); ++i)
+    for(int i = targetEnd2;
+        i < targetSentence.m_sentence->GetSize() && !targetSentence.m_targetConstraints[i].IsSet();
+        ++i)
       targetEnd2 = i;
 
     //std::cerr << "Longer: " << targetStart2 << " " << targetEnd2 << std::endl;
@@ -364,7 +443,7 @@ private:
     return false;
   }
 
-  std::vector<bool> LeaveOneOut(const TranslationOptionList &topts) const {
+  std::vector<bool> LeaveOneOut(const TranslationOptionList &topts, const std::vector<bool> &correct) const {
     UTIL_THROW_IF2(m_leaveOneOut.size() == 0 || ! m_train, "LeaveOneOut called in wrong setting!");
 
     float sourceRawCount = 0.0;
@@ -372,12 +451,14 @@ private:
 
     std::vector<bool> keepOpt;
 
-    TranslationOptionList::const_iterator iterTransOpt;
-    for(iterTransOpt = topts.begin(); iterTransOpt != topts.end(); ++iterTransOpt) {
-      const TargetPhrase &targetPhrase = (*iterTransOpt)->GetTargetPhrase();
+    for (size_t i = 0; i < topts.size(); i++) {
+      TranslationOption *topt = *(topts.begin() + i);
+      const TargetPhrase &targetPhrase = topt->GetTargetPhrase();
 
       // extract raw counts from phrase-table property
-      const CountsPhraseProperty *property = static_cast<const CountsPhraseProperty *>(targetPhrase.GetProperty("Counts"));
+      const CountsPhraseProperty *property =
+        static_cast<const CountsPhraseProperty *>(targetPhrase.GetProperty("Counts"));
+
       if (! property) {
         VERBOSE(1, "VW :: Counts not found for topt! Is this an OOV?\n");
         // keep all translation opts without updating, this is either OOV or bad usage...
@@ -394,7 +475,7 @@ private:
        }
      }
 
-      float discount = IsCorrectTranslationOption(**iterTransOpt) ? ONE : 0.0;
+      float discount = correct[i] ? ONE : 0.0;
       float target = property->GetTargetMarginal() - discount;
       float joint = property->GetJointCount() - discount;
       if (discount != 0.0) VERBOSE(2, "VW :: leaving one out!\n");
@@ -407,9 +488,9 @@ private:
       scores[0] = TransformScore(joint / target);         // P(f|e)
       scores[2] = TransformScore(joint / sourceRawCount); // P(e|f)
 
-      ScoreComponentCollection &scoreBreakDown = (*iterTransOpt)->GetScoreBreakdown();
+      ScoreComponentCollection &scoreBreakDown = topt->GetScoreBreakdown();
       scoreBreakDown.Assign(feature, scores);
-      (*iterTransOpt)->UpdateScore();
+      topt->UpdateScore();
       keepOpt.push_back(true);
     } else {
       // they only occurred together once, discard topt
@@ -425,6 +506,9 @@ private:
   std::string m_modelPath;
   std::string m_vwOptions;
 
+  // calculator of training loss
+  TrainingLoss *m_trainingLoss = NULL;
+
   // optionally contains feature name of a phrase table where we recompute scores with leaving one out
   std::string m_leaveOneOut;
 
diff --git a/moses/TranslationOptionList.h b/moses/TranslationOptionList.h
index 54ce94bb9..119a308f7 100644
--- a/moses/TranslationOptionList.h
+++ b/moses/TranslationOptionList.h
@@ -52,6 +52,7 @@ public:
   const TranslationOption *Get(size_t ind) const {
     return m_coll.at(ind);
   }
+
+  void Remove( size_t ind ) {
+    UTIL_THROW_IF2(ind >= m_coll.size(), "Out of bound index " << ind;
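
Note: TrainingLossBLEU in this revision returns the same 1/0 value as
TrainingLossBasic, so the BLEU-based loss is effectively a stub here. For
illustration only (this is not the patch author's implementation), a smoothed
sentence-level BLEU over unigrams and bigrams with add-one smoothing
("BLEU2+1") could be sketched as below, using the hypothetical helper name
Bleu2PlusOne and plain token vectors; extracting tokens from a TargetPhrase
is deliberately left out.

    #include <cmath>
    #include <map>
    #include <string>
    #include <vector>

    // Sketch: smoothed sentence-level BLEU restricted to unigrams and bigrams
    // ("BLEU2" with add-one smoothing), computed over plain token vectors.
    inline float Bleu2PlusOne(const std::vector<std::string> &hyp,
                              const std::vector<std::string> &ref) {
      if (hyp.empty() || ref.empty()) return 0.0f;
      float logBleu = 0.0f;
      for (size_t n = 1; n <= 2; ++n) {
        // collect reference n-gram counts
        std::map<std::vector<std::string>, int> refCounts;
        for (size_t i = 0; i + n <= ref.size(); ++i)
          ++refCounts[std::vector<std::string>(ref.begin() + i, ref.begin() + i + n)];
        // clipped n-gram matches in the hypothesis
        int matches = 0, total = 0;
        for (size_t i = 0; i + n <= hyp.size(); ++i) {
          ++total;
          std::vector<std::string> ngram(hyp.begin() + i, hyp.begin() + i + n);
          std::map<std::vector<std::string>, int>::iterator it = refCounts.find(ngram);
          if (it != refCounts.end() && it->second > 0) { ++matches; --it->second; }
        }
        // add-one smoothed precision for this n-gram order
        logBleu += std::log((matches + 1.0f) / (total + 1.0f));
      }
      logBleu /= 2.0f; // geometric mean of the two precisions
      if (hyp.size() < ref.size()) // brevity penalty
        logBleu += 1.0f - static_cast<float>(ref.size()) / hyp.size();
      return std::exp(logBleu);
    }

An operator() built on such a helper would then plausibly return
1.0 - Bleu2PlusOne(candidateTokens, correctTokens) for incorrect candidates
and 0.0 for correct ones, agreeing with the basic loss at the endpoints.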
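For context, the keys handled in SetParameter compose into a single feature
line in moses.ini. A hypothetical example (the feature name "VW0", the model
path, and the phrase-table name are illustrative, not taken from this patch):

    [feature]
    VW name=VW0 path=classifier.model train=true training-loss=bleu loss=logistic leave-one-out-from=TranslationModel0

Here train=true corresponds to the feature-extraction mode used with
vwtrainer (per the comment in the diff, not meant for decoding in Moses);
at test time one would presumably load a trained model via path and omit
train and training-loss.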