some refactoring of VW, towards customizable loss calculation in training

Ales Tamchyna 2015-03-04 14:26:26 +01:00
parent f9ec387a5b
commit 7601618477
2 changed files with 153 additions and 68 deletions


@@ -59,6 +59,37 @@ private:
   int m_min, m_max;
 };
 
+/**
+ * Calculation of training loss.
+ */
+class TrainingLoss
+{
+public:
+  virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const = 0;
+};
+
+/**
+ * Basic 1/0 training loss.
+ */
+class TrainingLossBasic : public TrainingLoss
+{
+public:
+  virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
+    return isCorrect ? 0.0 : 1.0;
+  }
+};
+
+/**
+ * BLEU2+1 training loss.
+ */
+class TrainingLossBLEU : public TrainingLoss
+{
+public:
+  virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
+    return isCorrect ? 0.0 : 1.0;
+  }
+};
+
 /**
  * VW thread-specific data about target sentence.
  */
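Note that TrainingLossBLEU is still a stub here: it returns the same 1/0 value as TrainingLossBasic. Judging by the name, it is meant to compute a BLEU2+1 loss, i.e. one minus sentence-level BLEU restricted to unigrams and bigrams, with add-one smoothing of the precisions. A minimal standalone sketch of that calculation, operating on plain token vectors instead of TargetPhrase (all names below are illustrative, not part of this commit):

    #include <cmath>
    #include <map>
    #include <string>
    #include <vector>

    // Sentence-level BLEU with n-gram orders 1..2 and add-one smoothing.
    float SentenceBLEU2PlusOne(const std::vector<std::string> &cand,
                               const std::vector<std::string> &ref) {
      const size_t maxOrder = 2;
      double logSum = 0.0;
      for (size_t n = 1; n <= maxOrder; ++n) {
        if (cand.size() < n)
          return 0.0; // candidate too short to have n-grams of this order
        // collect reference n-gram counts (for clipped matching)
        std::map<std::vector<std::string>, int> refCounts;
        for (size_t i = 0; i + n <= ref.size(); ++i)
          refCounts[std::vector<std::string>(ref.begin() + i, ref.begin() + i + n)]++;
        int matches = 0;
        const int total = cand.size() - n + 1;
        for (size_t i = 0; i + n <= cand.size(); ++i) {
          std::vector<std::string> ngram(cand.begin() + i, cand.begin() + i + n);
          if (refCounts[ngram] > 0) {
            --refCounts[ngram];
            ++matches;
          }
        }
        // the "+1" part: smoothed precision avoids log(0) on short phrases
        logSum += std::log((matches + 1.0) / (total + 1.0));
      }
      // brevity penalty for candidates shorter than the reference
      const double bp = (cand.size() >= ref.size())
                            ? 1.0
                            : std::exp(1.0 - double(ref.size()) / cand.size());
      return static_cast<float>(bp * std::exp(logSum / maxOrder));
    }

    // The loss the classifier would receive: 0 for the correct phrase,
    // approaching 1 as the candidate diverges from it.
    float BLEULossSketch(const std::vector<std::string> &candidate,
                         const std::vector<std::string> &correct,
                         bool isCorrect) {
      return isCorrect ? 0.0f : 1.0f - SentenceBLEU2PlusOne(candidate, correct);
    }

Add-one smoothing matters here because translation options are short phrases where exact bigram matches are often absent, so unsmoothed BLEU would collapse to zero for nearly every candidate.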
@@ -96,6 +127,7 @@ struct VWTargetSentence {
 };
 
 typedef ThreadLocalByFeatureStorage<Discriminative::Classifier, Discriminative::ClassifierFactory &> TLSClassifier;
+
 typedef ThreadLocalByFeatureStorage<VWTargetSentence> TLSTargetSentence;
 
 class VW : public StatelessFeatureFunction, public TLSTargetSentence
@@ -116,6 +148,10 @@ public:
       VERBOSE(1, "VW :: No loss function specified, assuming logistic loss.\n");
       m_normalizer = (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer();
     }
+
+    if (! m_trainingLoss) {
+      m_trainingLoss = (TrainingLoss *) new TrainingLossBasic();
+    }
   }
 
   virtual ~VW() {
@@ -150,76 +186,107 @@ public:
     VERBOSE(2, "VW :: Evaluating translation options\n");
 
-    const std::vector<VWFeatureBase*>& sourceFeatures = VWFeatureBase::GetSourceFeatures(GetScoreProducerDescription());
+    // which feature functions do we use (on the source and target side)
+    const std::vector<VWFeatureBase*>& sourceFeatures =
+      VWFeatureBase::GetSourceFeatures(GetScoreProducerDescription());
+    const std::vector<VWFeatureBase*>& targetFeatures =
+      VWFeatureBase::GetTargetFeatures(GetScoreProducerDescription());
 
     const WordsRange &sourceRange = translationOptionList.Get(0)->GetSourceWordsRange();
     const InputPath &inputPath = translationOptionList.Get(0)->GetInputPath();
 
-    // optionally update translation options using leave-one-out
-    std::vector<bool> keep = (m_leaveOneOut.size() > 0)
-      ? LeaveOneOut(translationOptionList)
-      : std::vector<bool>(translationOptionList.size(), true);
-
-    std::vector<float> losses(translationOptionList.size());
-    std::vector<float>::iterator iterLoss;
-    TranslationOptionList::const_iterator iterTransOpt;
-    std::vector<bool>::const_iterator iterKeep;
-
     if (m_train) {
-      // check which translation options are correct in advance
-      bool seenCorrect = false;
-      for(iterTransOpt = translationOptionList.begin(), iterLoss = losses.begin(), iterKeep = keep.begin() ;
-          iterTransOpt != translationOptionList.end() ; ++iterTransOpt, ++iterLoss, ++iterKeep) {
-        bool isCorrect = IsCorrectTranslationOption(**iterTransOpt);
-        *iterLoss = isCorrect ? 0.0 : 1.0;
-        if (isCorrect && *iterKeep) seenCorrect = true;
-      }
+      //
+      // extract features for training the classifier (only call this when using vwtrainer, not in Moses!)
+      //
+
+      // find which topts are correct
+      std::vector<bool> correct(translationOptionList.size());
+      for (size_t i = 0; i < translationOptionList.size(); i++)
+        correct[i] = IsCorrectTranslationOption(* translationOptionList.Get(i));
+
+      // optionally update translation options using leave-one-out
+      std::vector<bool> keep = (m_leaveOneOut.size() > 0)
+        ? LeaveOneOut(translationOptionList, correct)
+        : std::vector<bool>(translationOptionList.size(), true);
+
+      // check whether we (still) have some correct translation
+      int firstCorrect = -1;
+      for (size_t i = 0; i < translationOptionList.size(); i++) {
+        if (keep[i] && correct[i]) {
+          firstCorrect = i;
+          break;
+        }
+      }
 
       // do not train if there are no positive examples
-      if (! seenCorrect) {
+      if (firstCorrect == -1) {
        VERBOSE(2, "VW :: skipping topt collection, no correct translation for span\n");
        return;
       }
-    }
 
-    for(size_t i = 0; i < sourceFeatures.size(); ++i)
-      (*sourceFeatures[i])(input, inputPath, sourceRange, classifier);
+      // the first correct topt can be used by some loss functions
+      const TargetPhrase &correctPhrase = translationOptionList.Get(firstCorrect)->GetTargetPhrase();
 
-    const std::vector<VWFeatureBase*>& targetFeatures = VWFeatureBase::GetTargetFeatures(GetScoreProducerDescription());
+      // extract source side features
+      for(size_t i = 0; i < sourceFeatures.size(); ++i)
+        (*sourceFeatures[i])(input, inputPath, sourceRange, classifier);
 
-    for(iterTransOpt = translationOptionList.begin(), iterLoss = losses.begin(), iterKeep = keep.begin() ;
-        iterTransOpt != translationOptionList.end() ; ++iterTransOpt, ++iterLoss) {
-      if (! *iterKeep)
-        continue;
+      // go over topts, extract target side features and train the classifier
+      for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
+        // this topt was discarded by leaving one out
+        if (! keep[toptIdx])
+          continue;
 
-      const TargetPhrase &targetPhrase = (*iterTransOpt)->GetTargetPhrase();
-      for(size_t i = 0; i < targetFeatures.size(); ++i)
-        (*targetFeatures[i])(input, inputPath, targetPhrase, classifier);
+        // extract target-side features for each topt
+        const TargetPhrase &targetPhrase = translationOptionList.Get(toptIdx)->GetTargetPhrase();
+        for(size_t i = 0; i < targetFeatures.size(); ++i)
+          (*targetFeatures[i])(input, inputPath, targetPhrase, classifier);
 
-      if (! m_train) {
-        *iterLoss = classifier.Predict(MakeTargetLabel(targetPhrase));
-      } else {
-        classifier.Train(MakeTargetLabel(targetPhrase), *iterLoss);
-      }
-    }
+        float loss = (*m_trainingLoss)(targetPhrase, correctPhrase, correct[toptIdx]);
 
-    (*m_normalizer)(losses);
+        // train classifier on current example
+        classifier.Train(MakeTargetLabel(targetPhrase), loss);
+      }
+    } else {
+      //
+      // predict using a trained classifier, use this in decoding (=at test time)
+      //
 
-    for(iterTransOpt = translationOptionList.begin(), iterLoss = losses.begin(), iterKeep = keep.begin() ;
-        iterTransOpt != translationOptionList.end() ; ++iterTransOpt, ++iterLoss) {
-      if (! *iterKeep)
-        continue;
+      std::vector<float> losses(translationOptionList.size());
 
-      TranslationOption &transOpt = **iterTransOpt;
+      // extract source side features
+      for(size_t i = 0; i < sourceFeatures.size(); ++i)
+        (*sourceFeatures[i])(input, inputPath, sourceRange, classifier);
 
-      std::vector<float> newScores(m_numScoreComponents);
-      newScores[0] = FloorScore(TransformScore(*iterLoss));
+      for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
+        const TranslationOption *topt = translationOptionList.Get(toptIdx);
+        const TargetPhrase &targetPhrase = topt->GetTargetPhrase();
 
-      ScoreComponentCollection &scoreBreakDown = transOpt.GetScoreBreakdown();
-      scoreBreakDown.PlusEquals(this, newScores);
-      transOpt.UpdateScore();
+        // extract target-side features for each topt
+        for(size_t i = 0; i < targetFeatures.size(); ++i)
+          (*targetFeatures[i])(input, inputPath, targetPhrase, classifier);
+
+        // get classifier score
+        losses[toptIdx] = classifier.Predict(MakeTargetLabel(targetPhrase));
+      }
+
+      // normalize classifier scores to get a probability distribution
+      (*m_normalizer)(losses);
+
+      // update scores of topts
+      for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
+        TranslationOption *topt = *(translationOptionList.begin() + toptIdx);
+        std::vector<float> newScores(m_numScoreComponents);
+        newScores[0] = FloorScore(TransformScore(losses[toptIdx]));
+        ScoreComponentCollection &scoreBreakDown = topt->GetScoreBreakdown();
+        scoreBreakDown.PlusEquals(this, newScores);
+        topt->UpdateScore();
+      }
     }
   }
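In the prediction branch, the raw scores from Predict only become usable feature scores after normalization over the span's translation options. A rough sketch of what a logistic-loss normalizer might do, assuming it is a functor over the losses vector (the actual Discriminative::LogisticLossNormalizer lives elsewhere in Moses and may differ in detail):

    #include <cmath>
    #include <vector>

    // Map raw predictions through the logistic function and renormalize so
    // the values form a probability distribution over the span's topts.
    void LogisticNormalizeSketch(std::vector<float> &losses) {
      float sum = 0.0f;
      for (size_t i = 0; i < losses.size(); ++i) {
        // lower predicted loss => higher probability of being a good option
        losses[i] = 1.0f / (1.0f + std::exp(losses[i]));
        sum += losses[i];
      }
      for (size_t i = 0; i < losses.size(); ++i)
        losses[i] /= sum;
    }

The normalized value then enters the score breakdown through FloorScore(TransformScore(...)), i.e. in log space like other Moses feature scores.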
@@ -231,7 +298,6 @@ public:
                             ScoreComponentCollection* accumulator) const {
   }
 
-
   void SetParameter(const std::string& key, const std::string& value) {
     if (key == "train") {
       m_train = Scan<bool>(value);
@@ -241,10 +307,25 @@ public:
       m_vwOptions = value;
     } else if (key == "leave-one-out-from") {
       m_leaveOneOut = value;
+    } else if (key == "training-loss") {
+      // which type of loss to use for training
+      if (value == "basic") {
+        m_trainingLoss = (TrainingLoss *) new TrainingLossBasic();
+      } else if (value == "bleu") {
+        m_trainingLoss = (TrainingLoss *) new TrainingLossBLEU();
+      } else {
+        UTIL_THROW2("Unknown training loss type:" << value);
+      }
     } else if (key == "loss") {
-      m_normalizer = value == "logistic"
-        ? (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer()
-        : (Discriminative::Normalizer *) new Discriminative::SquaredLossNormalizer();
+      // which normalizer to use (theoretically depends on the loss function used for training the
+      // classifier (squared/logistic/hinge/...), hence the name "loss")
+      if (value == "logistic") {
+        m_normalizer = (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer();
+      } else if (value == "squared") {
+        m_normalizer = (Discriminative::Normalizer *) new Discriminative::SquaredLossNormalizer();
+      } else {
+        UTIL_THROW2("Unknown loss type:" << value);
+      }
     } else {
       StatelessFeatureFunction::SetParameter(key, value);
     }
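Since the training loss is now chosen in SetParameter, it can be configured per feature instance in moses.ini. A hypothetical feature line wiring everything together (the name, model path, and phrase-table name are placeholders; only the keys shown in SetParameter above are taken from this commit):

    [feature]
    VW name=VW0 path=/path/to/vw-model train=true training-loss=bleu loss=logistic leave-one-out-from=TranslationModel0

With training-loss=bleu the TrainingLossBLEU stub is selected; until it is fleshed out, both settings behave like the basic 1/0 loss.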
@@ -255,12 +336,13 @@ public:
     if (! m_train)
       return;
 
-    UTIL_THROW_IF2(source.GetType() != TabbedSentenceInput, "This feature function requires the TabbedSentence input type");
+    UTIL_THROW_IF2(source.GetType() != TabbedSentenceInput,
+                   "This feature function requires the TabbedSentence input type");
 
     const TabbedSentence& tabbedSentence = static_cast<const TabbedSentence&>(source);
-    UTIL_THROW_IF2(tabbedSentence.GetColumns().size() < 2, "TabbedSentence must contain target<tab>alignment");
+    UTIL_THROW_IF2(tabbedSentence.GetColumns().size() < 2,
+                   "TabbedSentence must contain target<tab>alignment");
 
     // target sentence represented as a phrase
     Phrase *target = new Phrase();
     target->CreateFromString(
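For reference, a TabbedSentence is one line per training sentence with tab-separated columns: the source first, then the target and word alignment required by the check above. A made-up example, with <tab> marking a literal tab and the alignment given as the usual source-target index pairs:

    das ist ein kleines haus<tab>this is a small house<tab>0-0 1-1 2-2 3-3 4-4

The target column becomes the Phrase and the alignment column the AlignmentInfo that IsCorrectTranslationOption later checks translation options against.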
@@ -279,11 +361,6 @@ public:
     targetSent.m_sentence = target;
     targetSent.m_alignment = alignment;
 
-
-    //std::cerr << static_cast<const Phrase&>(tabbedSentence) << std::endl;
-    //std::cerr << *target << std::endl;
-    //std::cerr << *alignment << std::endl;
-
     // pre-compute max- and min- aligned points for faster translation option checking
     targetSent.SetConstraints(source.GetSize());
   }
@@ -328,7 +405,9 @@ private:
       targetStart2 = i;
 
     int targetEnd2 = targetEnd;
-    for(int i = targetEnd2; i < targetSentence.m_sentence->GetSize() && !targetSentence.m_targetConstraints[i].IsSet(); ++i)
+    for(int i = targetEnd2;
+        i < targetSentence.m_sentence->GetSize() && !targetSentence.m_targetConstraints[i].IsSet();
+        ++i)
       targetEnd2 = i;
 
     //std::cerr << "Longer: " << targetStart2 << " " << targetEnd2 << std::endl;
@@ -364,7 +443,7 @@ private:
     return false;
   }
 
-  std::vector<bool> LeaveOneOut(const TranslationOptionList &topts) const {
+  std::vector<bool> LeaveOneOut(const TranslationOptionList &topts, const std::vector<bool> &correct) const {
     UTIL_THROW_IF2(m_leaveOneOut.size() == 0 || ! m_train, "LeaveOneOut called in wrong setting!");
 
     float sourceRawCount = 0.0;
@@ -372,12 +451,14 @@ private:
     std::vector<bool> keepOpt;
 
-    TranslationOptionList::const_iterator iterTransOpt;
-    for(iterTransOpt = topts.begin(); iterTransOpt != topts.end(); ++iterTransOpt) {
-      const TargetPhrase &targetPhrase = (*iterTransOpt)->GetTargetPhrase();
+    for (size_t i = 0; i < topts.size(); i++) {
+      TranslationOption *topt = *(topts.begin() + i);
+      const TargetPhrase &targetPhrase = topt->GetTargetPhrase();
 
       // extract raw counts from phrase-table property
-      const CountsPhraseProperty *property = static_cast<const CountsPhraseProperty *>(targetPhrase.GetProperty("Counts"));
+      const CountsPhraseProperty *property =
+        static_cast<const CountsPhraseProperty *>(targetPhrase.GetProperty("Counts"));
+
       if (! property) {
         VERBOSE(1, "VW :: Counts not found for topt! Is this an OOV?\n");
         // keep all translation opts without updating, this is either OOV or bad usage...
@@ -394,7 +475,7 @@ private:
         }
       }
 
-      float discount = IsCorrectTranslationOption(**iterTransOpt) ? ONE : 0.0;
+      float discount = correct[i] ? ONE : 0.0;
       float target = property->GetTargetMarginal() - discount;
       float joint = property->GetJointCount() - discount;
       if (discount != 0.0) VERBOSE(2, "VW :: leaving one out!\n");
@@ -407,9 +488,9 @@ private:
         scores[0] = TransformScore(joint / target);         // P(f|e)
         scores[2] = TransformScore(joint / sourceRawCount); // P(e|f)
 
-        ScoreComponentCollection &scoreBreakDown = (*iterTransOpt)->GetScoreBreakdown();
+        ScoreComponentCollection &scoreBreakDown = topt->GetScoreBreakdown();
         scoreBreakDown.Assign(feature, scores);
-        (*iterTransOpt)->UpdateScore();
+        topt->UpdateScore();
 
         keepOpt.push_back(true);
       } else {
         // they only occurred together once, discard topt
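To make the leave-one-out discounting concrete, a worked example with made-up counts (a sketch; the real sourceRawCount is accumulated from the Counts property as in the code above):

    #include <cstdio>

    int main() {
      // Phrase pair with joint count 3, target marginal 4, source raw count 5,
      // present in the current training sentence => discount = 1 (ONE).
      const float discount = 1.0f;
      const float joint  = 3.0f - discount;  // 2
      const float target = 4.0f - discount;  // 3
      std::printf("P(f|e) = %.3f\n", joint / target);  // 0.667, then TransformScore
      std::printf("P(e|f) = %.3f\n", joint / 5.0f);    // 0.400 (source count handling simplified)
      // A pair seen only once (joint = 1) is discounted to 0 and discarded
      // (keepOpt gets false), so the classifier never trains on its own sole occurrence.
      return 0;
    }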
@@ -425,6 +506,9 @@ private:
   std::string m_modelPath;
   std::string m_vwOptions;
 
+  // calculator of training loss
+  TrainingLoss *m_trainingLoss = NULL;
+
   // optionally contains feature name of a phrase table where we recompute scores with leaving one out
   std::string m_leaveOneOut;


@@ -52,6 +52,7 @@ public:
   const TranslationOption *Get(size_t ind) const {
     return m_coll.at(ind);
   }
+
   void Remove( size_t ind ) {
     UTIL_THROW_IF2(ind >= m_coll.size(),
                    "Out of bound index " << ind);