some refactoring of VW, towards customizable loss calculation in training

This commit is contained in:
Ales Tamchyna 2015-03-04 14:26:26 +01:00
parent f9ec387a5b
commit 7601618477
2 changed files with 153 additions and 68 deletions

View File

@@ -59,6 +59,37 @@ private:
int m_min, m_max;
};
/**
* Calculation of training loss.
*/
class TrainingLoss
{
public:
virtual ~TrainingLoss() {} // virtual dtor: instances may be deleted through a base-class pointer
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const = 0;
};
/**
* Basic 1/0 training loss.
*/
class TrainingLossBasic : public TrainingLoss
{
public:
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
return isCorrect ? 0.0 : 1.0;
}
};
/**
* BLEU2+1 training loss.
*/
class TrainingLossBLEU : public TrainingLoss
{
public:
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
// TODO stub: identical to the basic 1/0 loss for now; the actual BLEU2+1
// score between candidate and correct is not computed yet
return isCorrect ? 0.0 : 1.0;
}
};
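Aside: since TrainingLossBLEU is only a stub in this commit, here is a rough idea of what a BLEU2+1 loss (bigram BLEU with add-one smoothed precisions, loss defined as 1 − BLEU) could look like behind this interface. This is a hedged sketch, not code from the commit; the GetTokens helper it assumes for pulling surface tokens out of a TargetPhrase is hypothetical.

#include <cmath>
#include <map>
#include <string>
#include <vector>

// Hypothetical helper (not part of Moses): surface tokens of a TargetPhrase.
std::vector<std::string> GetTokens(const TargetPhrase &phrase);

// BLEU with n-grams up to maxOrder and add-one ("+1") smoothed precisions.
float SmoothedBLEU(const std::vector<std::string> &cand,
                   const std::vector<std::string> &ref,
                   size_t maxOrder = 2)
{
  if (cand.empty() || ref.empty()) return 0.0f;
  double logPrecisionSum = 0.0;
  for (size_t order = 1; order <= maxOrder; ++order) {
    // reference n-gram counts of the current order
    std::map<std::vector<std::string>, int> refCounts;
    for (size_t i = 0; i + order <= ref.size(); ++i)
      refCounts[std::vector<std::string>(ref.begin() + i, ref.begin() + i + order)]++;
    // clipped n-gram matches in the candidate
    int matches = 0, total = 0;
    std::map<std::vector<std::string>, int> used;
    for (size_t i = 0; i + order <= cand.size(); ++i) {
      std::vector<std::string> ngram(cand.begin() + i, cand.begin() + i + order);
      total++;
      if (++used[ngram] <= refCounts[ngram]) matches++;
    }
    // add-one smoothing keeps each precision (and its log) finite
    logPrecisionSum += std::log((matches + 1.0) / (total + 1.0));
  }
  // brevity penalty for short candidates
  double bp = cand.size() >= ref.size()
              ? 1.0 : std::exp(1.0 - double(ref.size()) / cand.size());
  return static_cast<float>(bp * std::exp(logPrecisionSum / maxOrder));
}

The operator() body would then return 1.0 - SmoothedBLEU(GetTokens(candidate), GetTokens(correct)), so the correct translation gets loss 0 and dissimilar candidates approach 1.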
/**
* VW thread-specific data about target sentence.
*/
@@ -96,6 +127,7 @@ struct VWTargetSentence {
};
typedef ThreadLocalByFeatureStorage<Discriminative::Classifier, Discriminative::ClassifierFactory &> TLSClassifier;
typedef ThreadLocalByFeatureStorage<VWTargetSentence> TLSTargetSentence;
class VW : public StatelessFeatureFunction, public TLSTargetSentence
@@ -116,6 +148,10 @@ public:
VERBOSE(1, "VW :: No loss function specified, assuming logistic loss.\n");
m_normalizer = (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer();
}
if (! m_trainingLoss) {
m_trainingLoss = (TrainingLoss *) new TrainingLossBasic();
}
}
virtual ~VW() {
@@ -150,76 +186,107 @@ public:
VERBOSE(2, "VW :: Evaluating translation options\n");
// which feature functions do we use (on the source and target side)
const std::vector<VWFeatureBase*>& sourceFeatures =
VWFeatureBase::GetSourceFeatures(GetScoreProducerDescription());
const std::vector<VWFeatureBase*>& targetFeatures =
VWFeatureBase::GetTargetFeatures(GetScoreProducerDescription());
const WordsRange &sourceRange = translationOptionList.Get(0)->GetSourceWordsRange();
const InputPath &inputPath = translationOptionList.Get(0)->GetInputPath();
if (m_train) {
//
// extract features for training the classifier (only call this when using vwtrainer, not in Moses!)
//
// find which topts are correct
std::vector<bool> correct(translationOptionList.size());
for (size_t i = 0; i < translationOptionList.size(); i++)
correct[i] = IsCorrectTranslationOption(* translationOptionList.Get(i));
// optionally update translation options using leave-one-out
std::vector<bool> keep = (m_leaveOneOut.size() > 0)
? LeaveOneOut(translationOptionList, correct)
: std::vector<bool>(translationOptionList.size(), true);
// check whether we (still) have some correct translation
int firstCorrect = -1;
for (size_t i = 0; i < translationOptionList.size(); i++) {
if (keep[i] && correct[i]) {
firstCorrect = i;
break;
}
}
// do not train if there are no positive examples
if (firstCorrect == -1) {
VERBOSE(2, "VW :: skipping topt collection, no correct translation for span\n");
return;
}
// the first correct topt can be used by some loss functions
const TargetPhrase &correctPhrase = translationOptionList.Get(firstCorrect)->GetTargetPhrase();
// extract source side features
for(size_t i = 0; i < sourceFeatures.size(); ++i)
(*sourceFeatures[i])(input, inputPath, sourceRange, classifier);
// go over topts, extract target side features and train the classifier
for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
// this topt was discarded by leaving one out
if (! keep[toptIdx])
continue;
// extract target-side features for each topt
const TargetPhrase &targetPhrase = translationOptionList.Get(toptIdx)->GetTargetPhrase();
for(size_t i = 0; i < targetFeatures.size(); ++i)
(*targetFeatures[i])(input, inputPath, targetPhrase, classifier);
float loss = (*m_trainingLoss)(targetPhrase, correctPhrase, correct[toptIdx]);
// train classifier on current example
classifier.Train(MakeTargetLabel(targetPhrase), loss);
}
} else {
//
// predict using a trained classifier, use this in decoding (=at test time)
//
std::vector<float> losses(translationOptionList.size());
// extract source side features
for(size_t i = 0; i < sourceFeatures.size(); ++i)
(*sourceFeatures[i])(input, inputPath, sourceRange, classifier);
for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
const TranslationOption *topt = translationOptionList.Get(toptIdx);
const TargetPhrase &targetPhrase = topt->GetTargetPhrase();
// extract target-side features for each topt
for(size_t i = 0; i < targetFeatures.size(); ++i)
(*targetFeatures[i])(input, inputPath, targetPhrase, classifier);
// get classifier score
losses[toptIdx] = classifier.Predict(MakeTargetLabel(targetPhrase));
}
// normalize classifier scores to get a probability distribution
(*m_normalizer)(losses);
// update scores of topts
for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
TranslationOption *topt = *(translationOptionList.begin() + toptIdx);
std::vector<float> newScores(m_numScoreComponents);
newScores[0] = FloorScore(TransformScore(losses[toptIdx]));
ScoreComponentCollection &scoreBreakDown = topt->GetScoreBreakdown();
scoreBreakDown.PlusEquals(this, newScores);
topt->UpdateScore();
}
}
}
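For readers unfamiliar with the Discriminative::Normalizer interface used above: (*m_normalizer)(losses) rescales the raw classifier predictions in place into a probability-like distribution before they become feature scores. As an assumption about the logistic variant (the real implementation lives elsewhere in Moses and may differ), a minimal sketch:

#include <cmath>
#include <vector>

// Sketch only: squash each raw score with a sigmoid, then renormalize
// so the values over all translation options sum to one.
void LogisticNormalizeSketch(std::vector<float> &scores)
{
  float sum = 0.0f;
  for (size_t i = 0; i < scores.size(); ++i) {
    scores[i] = 1.0f / (1.0f + std::exp(-scores[i]));
    sum += scores[i];
  }
  if (sum > 0.0f)
    for (size_t i = 0; i < scores.size(); ++i)
      scores[i] /= sum;
}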
@@ -231,7 +298,6 @@ public:
ScoreComponentCollection* accumulator) const {
}
void SetParameter(const std::string& key, const std::string& value) {
if (key == "train") {
m_train = Scan<bool>(value);
@@ -241,10 +307,25 @@ public:
m_vwOptions = value;
} else if (key == "leave-one-out-from") {
m_leaveOneOut = value;
} else if (key == "training-loss") {
// which type of loss to use for training
if (value == "basic") {
m_trainingLoss = (TrainingLoss *) new TrainingLossBasic();
} else if (value == "bleu") {
m_trainingLoss = (TrainingLoss *) new TrainingLossBLEU();
} else {
UTIL_THROW2("Unknown training loss type:" << value);
}
} else if (key == "loss") {
// which normalizer to use (theoretically depends on the loss function used for training the
// classifier: squared/logistic/hinge/..., hence the name "loss")
if (value == "logistic") {
m_normalizer = (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer();
} else if (value == "squared") {
m_normalizer = (Discriminative::Normalizer *) new Discriminative::SquaredLossNormalizer();
} else {
UTIL_THROW2("Unknown loss type:" << value);
}
} else {
StatelessFeatureFunction::SetParameter(key, value);
}
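For context, these keys are set on the VW feature line of the Moses/vwtrainer configuration. The line below is purely illustrative (feature name, model path, and anything beyond the keys parsed in SetParameter above are assumptions), showing the new training-loss switch next to the existing loss normalizer choice:

# hypothetical moses.ini feature line; names and paths are examples
VW name=VW0 path=classifier.model train=true training-loss=bleu loss=logistic leave-one-out-from=TranslationModel0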
@@ -255,12 +336,13 @@ public:
if (! m_train)
return;
UTIL_THROW_IF2(source.GetType() != TabbedSentenceInput,
"This feature function requires the TabbedSentence input type");
const TabbedSentence& tabbedSentence = static_cast<const TabbedSentence&>(source);
UTIL_THROW_IF2(tabbedSentence.GetColumns().size() < 2,
"TabbedSentence must contain target<tab>alignment");
// target sentence represented as a phrase
Phrase *target = new Phrase();
target->CreateFromString(
@@ -279,11 +361,6 @@ public:
targetSent.m_sentence = target;
targetSent.m_alignment = alignment;
// pre-compute max- and min- aligned points for faster translation option checking
targetSent.SetConstraints(source.GetSize());
}
@@ -328,7 +405,9 @@ private:
targetStart2 = i;
int targetEnd2 = targetEnd;
for(int i = targetEnd2;
i < targetSentence.m_sentence->GetSize() && !targetSentence.m_targetConstraints[i].IsSet();
++i)
targetEnd2 = i;
//std::cerr << "Longer: " << targetStart2 << " " << targetEnd2 << std::endl;
@@ -364,7 +443,7 @@ private:
return false;
}
std::vector<bool> LeaveOneOut(const TranslationOptionList &topts, const std::vector<bool> &correct) const {
UTIL_THROW_IF2(m_leaveOneOut.size() == 0 || ! m_train, "LeaveOneOut called in wrong setting!");
float sourceRawCount = 0.0;
@@ -372,12 +451,14 @@ private:
std::vector<bool> keepOpt;
for (size_t i = 0; i < topts.size(); i++) {
TranslationOption *topt = *(topts.begin() + i);
const TargetPhrase &targetPhrase = topt->GetTargetPhrase();
// extract raw counts from phrase-table property
const CountsPhraseProperty *property =
static_cast<const CountsPhraseProperty *>(targetPhrase.GetProperty("Counts"));
if (! property) {
VERBOSE(1, "VW :: Counts not found for topt! Is this an OOV?\n");
// keep all translation opts without updating, this is either OOV or bad usage...
@@ -394,7 +475,7 @@ private:
}
}
float discount = correct[i] ? ONE : 0.0;
float target = property->GetTargetMarginal() - discount;
float joint = property->GetJointCount() - discount;
if (discount != 0.0) VERBOSE(2, "VW :: leaving one out!\n");
@@ -407,9 +488,9 @@ private:
scores[0] = TransformScore(joint / target); // P(f|e)
scores[2] = TransformScore(joint / sourceRawCount); // P(e|f)
ScoreComponentCollection &scoreBreakDown = topt->GetScoreBreakdown();
scoreBreakDown.Assign(feature, scores);
topt->UpdateScore();
keepOpt.push_back(true);
} else {
// they only occurred together once, discard topt
@@ -425,6 +506,9 @@ private:
std::string m_modelPath;
std::string m_vwOptions;
// calculator of training loss
TrainingLoss *m_trainingLoss = NULL;
// optionally contains feature name of a phrase table where we recompute scores with leaving one out
std::string m_leaveOneOut;

View File

@@ -52,6 +52,7 @@ public:
const TranslationOption *Get(size_t ind) const {
return m_coll.at(ind);
}
void Remove(size_t ind) {
UTIL_THROW_IF2(ind >= m_coll.size(),
"Out of bound index " << ind);
m_coll.erase(m_coll.begin() + ind); // drop the translation option at the given index
}