mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-09-19 23:27:46 +03:00
minor refactoring in VW feature
This commit is contained in:
parent
34649b74d3
commit
ff1cae919b
82
moses/FF/VW/TrainingLoss.h
Normal file
82
moses/FF/VW/TrainingLoss.h
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <cmath>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "moses/Util.h"
|
||||||
|
#include "moses/StaticData.h"
|
||||||
|
#include "moses/Phrase.h"
|
||||||
|
|
||||||
|
namespace Moses
|
||||||
|
{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculation of training loss for VW.
|
||||||
|
*/
|
||||||
|
class TrainingLoss
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Basic 1/0 training loss.
|
||||||
|
*/
|
||||||
|
class TrainingLossBasic : public TrainingLoss
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
|
||||||
|
return isCorrect ? 0.0 : 1.0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
 * BLEU2+1 training loss.
 *
 * Loss is 1 - sentence-level BLEU of the candidate phrase against the
 * correct phrase, using n-grams up to length BLEU_N (= 2) with add-one
 * ("+1") smoothing of the n-gram counts.
 */
class TrainingLossBLEU : public TrainingLoss
{
public:
  virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
    std::multiset<std::string> refNgrams;
    float precision = 1.0;

    for (size_t size = 1; size <= BLEU_N; size++) {
      // Collect reference (correct) n-grams of the current length.
      // A multiset keeps duplicates, so repeated reference n-grams can
      // each be matched once (clipped counts).
      for (int pos = 0; pos <= (int)correct.GetSize() - (int)size; pos++) {
        refNgrams.insert(MakeNGram(correct, pos, pos + size));
      }

      int confirmed = 1; // we're BLEU+1
      int total = 1;
      for (int pos = 0; pos <= (int)candidate.GetSize() - (int)size; pos++) {
        total++;
        std::string ngram = MakeNGram(candidate, pos, pos + size);
        std::multiset<std::string>::iterator it;
        if ((it = refNgrams.find(ngram)) != refNgrams.end()) {
          confirmed++;
          // erase exactly one occurrence so each reference n-gram is
          // credited at most once
          refNgrams.erase(it);
        }
      }
      // fold this n-gram order's smoothed precision into the product
      precision *= (float)confirmed / total;
    }

    // NOTE(review): this computes exp((1 - r) / c), not the textbook BLEU
    // brevity penalty min(1, exp(1 - r/c)) — presumably a deliberate
    // smooth variant for phrase-level scoring; confirm before changing.
    float brevityPenalty = exp((float)(1.0 - correct.GetSize()) / candidate.GetSize());

    // geometric mean of the BLEU_N precisions, scaled by the brevity penalty
    return 1.0 - brevityPenalty * pow(precision, (float)1.0 / BLEU_N);
  }

private:
  // Build the n-gram covering word positions [start, end) of phrase as a
  // single space-separated string (surface forms in the globally
  // configured output factor order).
  std::string MakeNGram(const TargetPhrase &phrase, size_t start, size_t end) const {
    std::vector<std::string> words;
    while (start != end) {
      words.push_back(phrase.GetWord(start).GetString(StaticData::Instance().GetOutputFactorOrder(), false));
      start++;
    }
    return Join(" ", words);
  }

  // maximum n-gram length used in the BLEU computation
  static const size_t BLEU_N = 2;
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -3,8 +3,6 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <set>
|
|
||||||
#include <cmath>
|
|
||||||
|
|
||||||
#include "moses/FF/StatelessFeatureFunction.h"
|
#include "moses/FF/StatelessFeatureFunction.h"
|
||||||
#include "moses/PP/CountsPhraseProperty.h"
|
#include "moses/PP/CountsPhraseProperty.h"
|
||||||
@ -21,6 +19,7 @@
|
|||||||
#include "VWFeatureBase.h"
|
#include "VWFeatureBase.h"
|
||||||
#include "TabbedSentence.h"
|
#include "TabbedSentence.h"
|
||||||
#include "ThreadLocalByFeatureStorage.h"
|
#include "ThreadLocalByFeatureStorage.h"
|
||||||
|
#include "TrainingLoss.h"
|
||||||
|
|
||||||
namespace Moses
|
namespace Moses
|
||||||
{
|
{
|
||||||
@ -61,73 +60,6 @@ private:
|
|||||||
int m_min, m_max;
|
int m_min, m_max;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculation of training loss.
|
|
||||||
*/
|
|
||||||
class TrainingLoss
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Basic 1/0 training loss.
|
|
||||||
*/
|
|
||||||
class TrainingLossBasic : public TrainingLoss
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
|
|
||||||
return isCorrect ? 0.0 : 1.0;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* BLEU2+1 training loss.
|
|
||||||
*/
|
|
||||||
class TrainingLossBLEU : public TrainingLoss
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
|
|
||||||
std::multiset<std::string> refNgrams;
|
|
||||||
float precision = 1.0;
|
|
||||||
|
|
||||||
for (size_t size = 1; size <= BLEU_N; size++) {
|
|
||||||
for (int pos = 0; pos <= (int)correct.GetSize() - (int)size; pos++) {
|
|
||||||
refNgrams.insert(MakeNGram(correct, pos, pos + size));
|
|
||||||
}
|
|
||||||
|
|
||||||
int confirmed = 1; // we're BLEU+1
|
|
||||||
int total = 1;
|
|
||||||
for (int pos = 0; pos <= (int)candidate.GetSize() - (int)size; pos++) {
|
|
||||||
total++;
|
|
||||||
std::string ngram = MakeNGram(candidate, pos, pos + size);
|
|
||||||
std::multiset<std::string>::iterator it;
|
|
||||||
if ((it = refNgrams.find(ngram)) != refNgrams.end()) {
|
|
||||||
confirmed++;
|
|
||||||
refNgrams.erase(it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
precision *= (float)confirmed / total;
|
|
||||||
}
|
|
||||||
|
|
||||||
float brevityPenalty = exp((float)(1.0 - correct.GetSize()) / candidate.GetSize());
|
|
||||||
|
|
||||||
return 1.0 - brevityPenalty * pow(precision, (float)1.0 / BLEU_N);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::string MakeNGram(const TargetPhrase &phrase, size_t start, size_t end) const {
|
|
||||||
std::vector<std::string> words;
|
|
||||||
while (start != end) {
|
|
||||||
words.push_back(phrase.GetWord(start).GetString(StaticData::Instance().GetOutputFactorOrder(), false));
|
|
||||||
start++;
|
|
||||||
}
|
|
||||||
return Join(" ", words);
|
|
||||||
}
|
|
||||||
|
|
||||||
static const size_t BLEU_N = 2;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* VW thread-specific data about target sentence.
|
* VW thread-specific data about target sentence.
|
||||||
*/
|
*/
|
||||||
|
Loading…
Reference in New Issue
Block a user