minor refactoring in VW feature

This commit is contained in:
Ales Tamchyna 2015-03-04 17:40:05 +01:00
parent 34649b74d3
commit ff1cae919b
2 changed files with 83 additions and 69 deletions

View File

@ -0,0 +1,82 @@
#pragma once
#include <set>
#include <cmath>
#include <string>
#include "moses/Util.h"
#include "moses/StaticData.h"
#include "moses/Phrase.h"
namespace Moses
{
/**
* Calculation of training loss for VW.
*/
class TrainingLoss
{
public:
  // Polymorphic interface: a virtual destructor makes deleting a derived
  // loss object through a TrainingLoss* well-defined (without it, such a
  // delete is undefined behavior).
  virtual ~TrainingLoss() {}

  /**
   * Loss incurred by proposing @a candidate when the reference translation
   * is @a correct; @a isCorrect flags an exact match.  Implementations
   * return 0 for a perfect candidate and larger values for worse ones.
   */
  virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const = 0;
};
/**
* Basic 1/0 training loss.
*/
class TrainingLossBasic : public TrainingLoss
{
public:
  /**
   * Zero-one loss: 0 when the candidate is the correct translation,
   * 1 otherwise.  The phrases themselves are not inspected.
   */
  virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
    if (isCorrect)
      return 0.0;
    return 1.0;
  }
};
/**
* BLEU2+1 training loss.
*/
class TrainingLossBLEU : public TrainingLoss
{
public:
  /**
   * Loss = 1 - sentence-level BLEU of @a candidate against @a correct,
   * computed over n-gram orders 1..BLEU_N with add-one (BLEU+1) smoothing
   * of the modified precisions.  @a isCorrect is unused; an exact match
   * naturally yields BLEU = 1 and therefore loss 0.
   */
  virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
    std::multiset<std::string> refNgrams;
    float precision = 1.0;

    for (size_t size = 1; size <= BLEU_N; size++) {
      // Collect reference n-grams of the current order.  Orders cannot
      // collide in the multiset: the space-joined strings differ in length.
      for (int pos = 0; pos <= (int)correct.GetSize() - (int)size; pos++) {
        refNgrams.insert(MakeNGram(correct, pos, pos + size));
      }

      int confirmed = 1; // we're BLEU+1
      int total = 1;
      for (int pos = 0; pos <= (int)candidate.GetSize() - (int)size; pos++) {
        total++;
        std::string ngram = MakeNGram(candidate, pos, pos + size);
        std::multiset<std::string>::iterator it = refNgrams.find(ngram);
        if (it != refNgrams.end()) {
          confirmed++;
          refNgrams.erase(it); // clipping: each reference n-gram matches once
        }
      }
      precision *= (float)confirmed / total;
    }

    // Standard BLEU brevity penalty: exp(1 - r/c) when the candidate is
    // shorter than the reference, 1 otherwise.  The previous code computed
    // exp((1 - r) / c) — a misparenthesization that penalized even
    // candidates of exactly the reference length — and divided by zero for
    // an empty candidate; an empty candidate now gets the maximum loss.
    if (candidate.GetSize() == 0)
      return 1.0;
    float brevityPenalty = 1.0;
    if (candidate.GetSize() < correct.GetSize()) {
      brevityPenalty = exp(1.0f - (float)correct.GetSize() / candidate.GetSize());
    }

    return 1.0 - brevityPenalty * pow(precision, (float)1.0 / BLEU_N);
  }

private:
  // Space-joined surface forms of words [start, end) of @a phrase, rendered
  // with the globally configured output factor order.
  std::string MakeNGram(const TargetPhrase &phrase, size_t start, size_t end) const {
    std::vector<std::string> words;
    while (start != end) {
      words.push_back(phrase.GetWord(start).GetString(StaticData::Instance().GetOutputFactorOrder(), false));
      start++;
    }
    return Join(" ", words);
  }

  static const size_t BLEU_N = 2; // BLEU2 (+1 smoothing)
};
}

View File

@ -3,8 +3,6 @@
#include <string>
#include <map>
#include <limits>
#include <set>
#include <cmath>
#include "moses/FF/StatelessFeatureFunction.h"
#include "moses/PP/CountsPhraseProperty.h"
@ -21,6 +19,7 @@
#include "VWFeatureBase.h"
#include "TabbedSentence.h"
#include "ThreadLocalByFeatureStorage.h"
#include "TrainingLoss.h"
namespace Moses
{
@ -61,73 +60,6 @@ private:
int m_min, m_max;
};
/**
* Calculation of training loss.
*/
class TrainingLoss
{
public:
// Loss incurred by proposing `candidate` when the reference is `correct`;
// `isCorrect` flags an exact match.
// NOTE(review): polymorphic base without a virtual destructor — deleting a
// derived loss through a TrainingLoss* would be undefined behavior; confirm
// instances are never owned polymorphically.
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const = 0;
};
/**
* Basic 1/0 training loss.
*/
class TrainingLossBasic : public TrainingLoss
{
public:
// Zero-one loss: 0 for the correct translation, 1 for anything else; the
// phrase contents are not inspected.
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
return isCorrect ? 0.0 : 1.0;
}
};
/**
* BLEU2+1 training loss.
*/
class TrainingLossBLEU : public TrainingLoss
{
public:
// Loss = 1 - sentence-level BLEU of `candidate` against `correct`, over
// n-gram orders 1..BLEU_N with add-one (BLEU+1) smoothing.  `isCorrect` is
// unused: an exact match yields BLEU = 1 and hence loss 0.
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
std::multiset<std::string> refNgrams;
float precision = 1.0;
for (size_t size = 1; size <= BLEU_N; size++) {
// reference n-grams of the current order; orders cannot collide in the
// multiset because the space-joined strings differ in token count
for (int pos = 0; pos <= (int)correct.GetSize() - (int)size; pos++) {
refNgrams.insert(MakeNGram(correct, pos, pos + size));
}
int confirmed = 1; // we're BLEU+1
int total = 1;
for (int pos = 0; pos <= (int)candidate.GetSize() - (int)size; pos++) {
total++;
std::string ngram = MakeNGram(candidate, pos, pos + size);
std::multiset<std::string>::iterator it;
if ((it = refNgrams.find(ngram)) != refNgrams.end()) {
confirmed++;
// clipped counts: each reference n-gram may be matched at most once
refNgrams.erase(it);
}
}
precision *= (float)confirmed / total;
}
// NOTE(review): this evaluates exp((1 - r) / c), not the standard BLEU
// brevity penalty exp(1 - r/c) clamped at 1 — looks like a
// misparenthesization, and it divides by zero for an empty candidate;
// confirm intent before relying on the absolute loss values.
float brevityPenalty = exp((float)(1.0 - correct.GetSize()) / candidate.GetSize());
return 1.0 - brevityPenalty * pow(precision, (float)1.0 / BLEU_N);
}
private:
// Space-joined surface forms of words [start, end) of `phrase`, rendered
// with the globally configured output factor order.
std::string MakeNGram(const TargetPhrase &phrase, size_t start, size_t end) const {
std::vector<std::string> words;
while (start != end) {
words.push_back(phrase.GetWord(start).GetString(StaticData::Instance().GetOutputFactorOrder(), false));
start++;
}
return Join(" ", words);
}
static const size_t BLEU_N = 2; // BLEU2 (+1 smoothing)
};
/**
* VW thread-specific data about target sentence.
*/