mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-08-16 15:00:33 +03:00
some refactoring of VW, towards customizable loss calculation in training
This commit is contained in:
parent
f9ec387a5b
commit
7601618477
200
moses/FF/VW/VW.h
200
moses/FF/VW/VW.h
@ -59,6 +59,37 @@ private:
|
||||
int m_min, m_max;
|
||||
};
|
||||
|
||||
/**
|
||||
* Calculation of training loss.
|
||||
*/
|
||||
class TrainingLoss
|
||||
{
|
||||
public:
|
||||
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* Basic 1/0 training loss.
|
||||
*/
|
||||
class TrainingLossBasic : public TrainingLoss
|
||||
{
|
||||
public:
|
||||
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
|
||||
return isCorrect ? 0.0 : 1.0;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* BLEU2+1 training loss.
|
||||
*/
|
||||
class TrainingLossBLEU : public TrainingLoss
|
||||
{
|
||||
public:
|
||||
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
|
||||
return isCorrect ? 0.0 : 1.0;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* VW thread-specific data about target sentence.
|
||||
*/
|
||||
@ -96,6 +127,7 @@ struct VWTargetSentence {
|
||||
};
|
||||
|
||||
typedef ThreadLocalByFeatureStorage<Discriminative::Classifier, Discriminative::ClassifierFactory &> TLSClassifier;
|
||||
|
||||
typedef ThreadLocalByFeatureStorage<VWTargetSentence> TLSTargetSentence;
|
||||
|
||||
class VW : public StatelessFeatureFunction, public TLSTargetSentence
|
||||
@ -116,6 +148,10 @@ public:
|
||||
VERBOSE(1, "VW :: No loss function specified, assuming logistic loss.\n");
|
||||
m_normalizer = (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer();
|
||||
}
|
||||
|
||||
if (! m_trainingLoss) {
|
||||
m_trainingLoss = (TrainingLoss *) new TrainingLossBasic();
|
||||
}
|
||||
}
|
||||
|
||||
virtual ~VW() {
|
||||
@ -150,76 +186,107 @@ public:
|
||||
|
||||
VERBOSE(2, "VW :: Evaluating translation options\n");
|
||||
|
||||
const std::vector<VWFeatureBase*>& sourceFeatures = VWFeatureBase::GetSourceFeatures(GetScoreProducerDescription());
|
||||
// which feature functions do we use (on the source and target side)
|
||||
const std::vector<VWFeatureBase*>& sourceFeatures =
|
||||
VWFeatureBase::GetSourceFeatures(GetScoreProducerDescription());
|
||||
|
||||
const std::vector<VWFeatureBase*>& targetFeatures =
|
||||
VWFeatureBase::GetTargetFeatures(GetScoreProducerDescription());
|
||||
|
||||
const WordsRange &sourceRange = translationOptionList.Get(0)->GetSourceWordsRange();
|
||||
const InputPath &inputPath = translationOptionList.Get(0)->GetInputPath();
|
||||
|
||||
if (m_train) {
|
||||
//
|
||||
// extract features for training the classifier (only call this when using vwtrainer, not in Moses!)
|
||||
//
|
||||
|
||||
// find which topts are correct
|
||||
std::vector<bool> correct(translationOptionList.size());
|
||||
for (size_t i = 0; i < translationOptionList.size(); i++)
|
||||
correct[i] = IsCorrectTranslationOption(* translationOptionList.Get(i));
|
||||
|
||||
// optionally update translation options using leave-one-out
|
||||
std::vector<bool> keep = (m_leaveOneOut.size() > 0)
|
||||
? LeaveOneOut(translationOptionList)
|
||||
? LeaveOneOut(translationOptionList, correct)
|
||||
: std::vector<bool>(translationOptionList.size(), true);
|
||||
|
||||
std::vector<float> losses(translationOptionList.size());
|
||||
std::vector<float>::iterator iterLoss;
|
||||
TranslationOptionList::const_iterator iterTransOpt;
|
||||
std::vector<bool>::const_iterator iterKeep;
|
||||
|
||||
if (m_train) {
|
||||
// check which translation options are correct in advance
|
||||
bool seenCorrect = false;
|
||||
for(iterTransOpt = translationOptionList.begin(), iterLoss = losses.begin(), iterKeep = keep.begin() ;
|
||||
iterTransOpt != translationOptionList.end() ; ++iterTransOpt, ++iterLoss, ++iterKeep) {
|
||||
bool isCorrect = IsCorrectTranslationOption(**iterTransOpt);
|
||||
*iterLoss = isCorrect ? 0.0 : 1.0;
|
||||
if (isCorrect && *iterKeep) seenCorrect = true;
|
||||
// check whether we (still) have some correct translation
|
||||
int firstCorrect = -1;
|
||||
for (size_t i = 0; i < translationOptionList.size(); i++) {
|
||||
if (keep[i] && correct[i]) {
|
||||
firstCorrect = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// do not train if there are no positive examples
|
||||
if (! seenCorrect) {
|
||||
if (firstCorrect == -1) {
|
||||
VERBOSE(2, "VW :: skipping topt collection, no correct translation for span\n");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// the first correct topt can be used by some loss functions
|
||||
const TargetPhrase &correctPhrase = translationOptionList.Get(firstCorrect)->GetTargetPhrase();
|
||||
|
||||
// extract source side features
|
||||
for(size_t i = 0; i < sourceFeatures.size(); ++i)
|
||||
(*sourceFeatures[i])(input, inputPath, sourceRange, classifier);
|
||||
|
||||
const std::vector<VWFeatureBase*>& targetFeatures = VWFeatureBase::GetTargetFeatures(GetScoreProducerDescription());
|
||||
// go over topts, extract target side features and train the classifier
|
||||
for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
|
||||
|
||||
for(iterTransOpt = translationOptionList.begin(), iterLoss = losses.begin(), iterKeep = keep.begin() ;
|
||||
iterTransOpt != translationOptionList.end() ; ++iterTransOpt, ++iterLoss) {
|
||||
|
||||
if (! *iterKeep)
|
||||
// this topt was discarded by leaving one out
|
||||
if (! keep[toptIdx])
|
||||
continue;
|
||||
|
||||
const TargetPhrase &targetPhrase = (*iterTransOpt)->GetTargetPhrase();
|
||||
// extract target-side features for each topt
|
||||
const TargetPhrase &targetPhrase = translationOptionList.Get(toptIdx)->GetTargetPhrase();
|
||||
for(size_t i = 0; i < targetFeatures.size(); ++i)
|
||||
(*targetFeatures[i])(input, inputPath, targetPhrase, classifier);
|
||||
|
||||
if (! m_train) {
|
||||
*iterLoss = classifier.Predict(MakeTargetLabel(targetPhrase));
|
||||
} else {
|
||||
classifier.Train(MakeTargetLabel(targetPhrase), *iterLoss);
|
||||
float loss = (*m_trainingLoss)(targetPhrase, correctPhrase, correct[toptIdx]);
|
||||
|
||||
// train classifier on current example
|
||||
classifier.Train(MakeTargetLabel(targetPhrase), loss);
|
||||
}
|
||||
} else {
|
||||
//
|
||||
// predict using a trained classifier, use this in decoding (=at test time)
|
||||
//
|
||||
|
||||
std::vector<float> losses(translationOptionList.size());
|
||||
|
||||
// extract source side features
|
||||
for(size_t i = 0; i < sourceFeatures.size(); ++i)
|
||||
(*sourceFeatures[i])(input, inputPath, sourceRange, classifier);
|
||||
|
||||
for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
|
||||
const TranslationOption *topt = translationOptionList.Get(toptIdx);
|
||||
const TargetPhrase &targetPhrase = topt->GetTargetPhrase();
|
||||
|
||||
// extract target-side features for each topt
|
||||
for(size_t i = 0; i < targetFeatures.size(); ++i)
|
||||
(*targetFeatures[i])(input, inputPath, targetPhrase, classifier);
|
||||
|
||||
// get classifier score
|
||||
losses[toptIdx] = classifier.Predict(MakeTargetLabel(targetPhrase));
|
||||
}
|
||||
|
||||
// normalize classifier scores to get a probability distribution
|
||||
(*m_normalizer)(losses);
|
||||
|
||||
for(iterTransOpt = translationOptionList.begin(), iterLoss = losses.begin(), iterKeep = keep.begin() ;
|
||||
iterTransOpt != translationOptionList.end() ; ++iterTransOpt, ++iterLoss) {
|
||||
if (! *iterKeep)
|
||||
continue;
|
||||
|
||||
TranslationOption &transOpt = **iterTransOpt;
|
||||
|
||||
// update scores of topts
|
||||
for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {
|
||||
TranslationOption *topt = *(translationOptionList.begin() + toptIdx);
|
||||
std::vector<float> newScores(m_numScoreComponents);
|
||||
newScores[0] = FloorScore(TransformScore(*iterLoss));
|
||||
newScores[0] = FloorScore(TransformScore(losses[toptIdx]));
|
||||
|
||||
ScoreComponentCollection &scoreBreakDown = transOpt.GetScoreBreakdown();
|
||||
ScoreComponentCollection &scoreBreakDown = topt->GetScoreBreakdown();
|
||||
scoreBreakDown.PlusEquals(this, newScores);
|
||||
|
||||
transOpt.UpdateScore();
|
||||
topt->UpdateScore();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -231,7 +298,6 @@ public:
|
||||
ScoreComponentCollection* accumulator) const {
|
||||
}
|
||||
|
||||
|
||||
void SetParameter(const std::string& key, const std::string& value) {
|
||||
if (key == "train") {
|
||||
m_train = Scan<bool>(value);
|
||||
@ -241,10 +307,25 @@ public:
|
||||
m_vwOptions = value;
|
||||
} else if (key == "leave-one-out-from") {
|
||||
m_leaveOneOut = value;
|
||||
} else if (key == "training-loss") {
|
||||
// which type of loss to use for training
|
||||
if (value == "basic") {
|
||||
m_trainingLoss = (TrainingLoss *) new TrainingLossBasic();
|
||||
} else if (value == "bleu") {
|
||||
m_trainingLoss = (TrainingLoss *) new TrainingLossBLEU();
|
||||
} else {
|
||||
UTIL_THROW2("Unknown training loss type:" << value);
|
||||
}
|
||||
} else if (key == "loss") {
|
||||
m_normalizer = value == "logistic"
|
||||
? (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer()
|
||||
: (Discriminative::Normalizer *) new Discriminative::SquaredLossNormalizer();
|
||||
// which normalizer to use (theoretically depends on the loss function used for training the
|
||||
// classifier (squared/logistic/hinge/...), hence the name "loss"
|
||||
if (value == "logistic") {
|
||||
m_normalizer = (Discriminative::Normalizer *) new Discriminative::LogisticLossNormalizer();
|
||||
} else if (value == "squared") {
|
||||
m_normalizer = (Discriminative::Normalizer *) new Discriminative::SquaredLossNormalizer();
|
||||
} else {
|
||||
UTIL_THROW2("Unknown loss type:" << value);
|
||||
}
|
||||
} else {
|
||||
StatelessFeatureFunction::SetParameter(key, value);
|
||||
}
|
||||
@ -255,11 +336,12 @@ public:
|
||||
if (! m_train)
|
||||
return;
|
||||
|
||||
UTIL_THROW_IF2(source.GetType() != TabbedSentenceInput, "This feature function requires the TabbedSentence input type");
|
||||
UTIL_THROW_IF2(source.GetType() != TabbedSentenceInput,
|
||||
"This feature function requires the TabbedSentence input type");
|
||||
|
||||
const TabbedSentence& tabbedSentence = static_cast<const TabbedSentence&>(source);
|
||||
UTIL_THROW_IF2(tabbedSentence.GetColumns().size() < 2, "TabbedSentence must contain target<tab>alignment");
|
||||
|
||||
UTIL_THROW_IF2(tabbedSentence.GetColumns().size() < 2,
|
||||
"TabbedSentence must contain target<tab>alignment");
|
||||
|
||||
// target sentence represented as a phrase
|
||||
Phrase *target = new Phrase();
|
||||
@ -279,11 +361,6 @@ public:
|
||||
targetSent.m_sentence = target;
|
||||
targetSent.m_alignment = alignment;
|
||||
|
||||
//std::cerr << static_cast<const Phrase&>(tabbedSentence) << std::endl;
|
||||
//std::cerr << *target << std::endl;
|
||||
//std::cerr << *alignment << std::endl;
|
||||
|
||||
|
||||
// pre-compute max- and min- aligned points for faster translation option checking
|
||||
targetSent.SetConstraints(source.GetSize());
|
||||
}
|
||||
@ -328,7 +405,9 @@ private:
|
||||
targetStart2 = i;
|
||||
|
||||
int targetEnd2 = targetEnd;
|
||||
for(int i = targetEnd2; i < targetSentence.m_sentence->GetSize() && !targetSentence.m_targetConstraints[i].IsSet(); ++i)
|
||||
for(int i = targetEnd2;
|
||||
i < targetSentence.m_sentence->GetSize() && !targetSentence.m_targetConstraints[i].IsSet();
|
||||
++i)
|
||||
targetEnd2 = i;
|
||||
|
||||
//std::cerr << "Longer: " << targetStart2 << " " << targetEnd2 << std::endl;
|
||||
@ -364,7 +443,7 @@ private:
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<bool> LeaveOneOut(const TranslationOptionList &topts) const {
|
||||
std::vector<bool> LeaveOneOut(const TranslationOptionList &topts, const std::vector<bool> &correct) const {
|
||||
UTIL_THROW_IF2(m_leaveOneOut.size() == 0 || ! m_train, "LeaveOneOut called in wrong setting!");
|
||||
|
||||
float sourceRawCount = 0.0;
|
||||
@ -372,12 +451,14 @@ private:
|
||||
|
||||
std::vector<bool> keepOpt;
|
||||
|
||||
TranslationOptionList::const_iterator iterTransOpt;
|
||||
for(iterTransOpt = topts.begin(); iterTransOpt != topts.end(); ++iterTransOpt) {
|
||||
const TargetPhrase &targetPhrase = (*iterTransOpt)->GetTargetPhrase();
|
||||
for (size_t i = 0; i < topts.size(); i++) {
|
||||
TranslationOption *topt = *(topts.begin() + i);
|
||||
const TargetPhrase &targetPhrase = topt->GetTargetPhrase();
|
||||
|
||||
// extract raw counts from phrase-table property
|
||||
const CountsPhraseProperty *property = static_cast<const CountsPhraseProperty *>(targetPhrase.GetProperty("Counts"));
|
||||
const CountsPhraseProperty *property =
|
||||
static_cast<const CountsPhraseProperty *>(targetPhrase.GetProperty("Counts"));
|
||||
|
||||
if (! property) {
|
||||
VERBOSE(1, "VW :: Counts not found for topt! Is this an OOV?\n");
|
||||
// keep all translation opts without updating, this is either OOV or bad usage...
|
||||
@ -394,7 +475,7 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
float discount = IsCorrectTranslationOption(**iterTransOpt) ? ONE : 0.0;
|
||||
float discount = correct[i] ? ONE : 0.0;
|
||||
float target = property->GetTargetMarginal() - discount;
|
||||
float joint = property->GetJointCount() - discount;
|
||||
if (discount != 0.0) VERBOSE(2, "VW :: leaving one out!\n");
|
||||
@ -407,9 +488,9 @@ private:
|
||||
scores[0] = TransformScore(joint / target); // P(f|e)
|
||||
scores[2] = TransformScore(joint / sourceRawCount); // P(e|f)
|
||||
|
||||
ScoreComponentCollection &scoreBreakDown = (*iterTransOpt)->GetScoreBreakdown();
|
||||
ScoreComponentCollection &scoreBreakDown = topt->GetScoreBreakdown();
|
||||
scoreBreakDown.Assign(feature, scores);
|
||||
(*iterTransOpt)->UpdateScore();
|
||||
topt->UpdateScore();
|
||||
keepOpt.push_back(true);
|
||||
} else {
|
||||
// they only occurred together once, discard topt
|
||||
@ -425,6 +506,9 @@ private:
|
||||
std::string m_modelPath;
|
||||
std::string m_vwOptions;
|
||||
|
||||
// calculator of training loss
|
||||
TrainingLoss *m_trainingLoss = NULL;
|
||||
|
||||
// optionally contains feature name of a phrase table where we recompute scores with leaving one out
|
||||
std::string m_leaveOneOut;
|
||||
|
||||
|
@ -52,6 +52,7 @@ public:
|
||||
const TranslationOption *Get(size_t ind) const {
|
||||
return m_coll.at(ind);
|
||||
}
|
||||
|
||||
void Remove( size_t ind ) {
|
||||
UTIL_THROW_IF2(ind >= m_coll.size(),
|
||||
"Out of bound index " << ind);
|
||||
|
Loading…
Reference in New Issue
Block a user