mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-09-19 07:07:24 +03:00
Move phrase scoring from LanguageModel to LanguageModelImplementation.
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4324 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
parent
c9995dc44c
commit
16e37adbe0
@ -96,168 +96,10 @@ float LanguageModel::GetOOVWeight() const
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void LanguageModel::CalcScore(const Phrase &phrase
|
|
||||||
, float &fullScore
|
|
||||||
, float &ngramScore
|
|
||||||
, size_t &oovCount) const
|
|
||||||
{
|
|
||||||
fullScore = 0;
|
|
||||||
ngramScore = 0;
|
|
||||||
|
|
||||||
oovCount = 0;
|
|
||||||
|
|
||||||
size_t phraseSize = phrase.GetSize();
|
|
||||||
if (!phraseSize) return;
|
|
||||||
|
|
||||||
vector<const Word*> contextFactor;
|
|
||||||
contextFactor.reserve(GetNGramOrder());
|
|
||||||
std::auto_ptr<FFState> state(m_implementation->NewState((phrase.GetWord(0) == m_implementation->GetSentenceStartArray()) ?
|
|
||||||
m_implementation->GetBeginSentenceState() : m_implementation->GetNullContextState()));
|
|
||||||
size_t currPos = 0;
|
|
||||||
while (currPos < phraseSize) {
|
|
||||||
const Word &word = phrase.GetWord(currPos);
|
|
||||||
|
|
||||||
if (word.IsNonTerminal()) {
|
|
||||||
// do nothing. reset ngram. needed to score targbet phrases during pt loading in chart decoding
|
|
||||||
if (!contextFactor.empty()) {
|
|
||||||
// TODO: state operator= ?
|
|
||||||
state.reset(m_implementation->NewState(m_implementation->GetNullContextState()));
|
|
||||||
contextFactor.clear();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ShiftOrPush(contextFactor, word);
|
|
||||||
assert(contextFactor.size() <= GetNGramOrder());
|
|
||||||
|
|
||||||
if (word == m_implementation->GetSentenceStartArray()) {
|
|
||||||
// do nothing, don't include prob for <s> unigram
|
|
||||||
assert(currPos == 0);
|
|
||||||
} else {
|
|
||||||
LMResult result = m_implementation->GetValueGivenState(contextFactor, *state);
|
|
||||||
fullScore += result.score;
|
|
||||||
if (contextFactor.size() == GetNGramOrder())
|
|
||||||
ngramScore += result.score;
|
|
||||||
if (contextFactor.size() == 1 && result.unknown)
|
|
||||||
++oovCount;
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
currPos++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void LanguageModel::ShiftOrPush(vector<const Word*> &contextFactor, const Word &word) const
|
|
||||||
{
|
|
||||||
if (contextFactor.size() < GetNGramOrder()) {
|
|
||||||
contextFactor.push_back(&word);
|
|
||||||
} else {
|
|
||||||
// shift
|
|
||||||
for (size_t currNGramOrder = 0 ; currNGramOrder < GetNGramOrder() - 1 ; currNGramOrder++) {
|
|
||||||
contextFactor[currNGramOrder] = contextFactor[currNGramOrder + 1];
|
|
||||||
}
|
|
||||||
contextFactor[GetNGramOrder() - 1] = &word;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const FFState* LanguageModel::EmptyHypothesisState(const InputType &/*input*/) const
|
const FFState* LanguageModel::EmptyHypothesisState(const InputType &/*input*/) const
|
||||||
{
|
{
|
||||||
// This is actually correct. The empty _hypothesis_ has <s> in it. Phrases use m_emptyContextState.
|
// This is actually correct. The empty _hypothesis_ has <s> in it. Phrases use m_emptyContextState.
|
||||||
return m_implementation->NewState(m_implementation->GetBeginSentenceState());
|
return m_implementation->NewState(m_implementation->GetBeginSentenceState());
|
||||||
}
|
}
|
||||||
|
|
||||||
FFState* LanguageModel::Evaluate(
|
|
||||||
const Hypothesis& hypo,
|
|
||||||
const FFState* ps,
|
|
||||||
ScoreComponentCollection* out) const
|
|
||||||
{
|
|
||||||
// In this function, we only compute the LM scores of n-grams that overlap a
|
|
||||||
// phrase boundary. Phrase-internal scores are taken directly from the
|
|
||||||
// translation option.
|
|
||||||
|
|
||||||
// In the case of unigram language models, there is no overlap, so we don't
|
|
||||||
// need to do anything.
|
|
||||||
if(GetNGramOrder() <= 1)
|
|
||||||
return NULL;
|
|
||||||
|
|
||||||
clock_t t = 0;
|
|
||||||
IFVERBOSE(2) {
|
|
||||||
t = clock(); // track time
|
|
||||||
}
|
|
||||||
|
|
||||||
// Empty phrase added? nothing to be done
|
|
||||||
if (hypo.GetCurrTargetLength() == 0)
|
|
||||||
return ps ? m_implementation->NewState(ps) : NULL;
|
|
||||||
|
|
||||||
const size_t currEndPos = hypo.GetCurrTargetWordsRange().GetEndPos();
|
|
||||||
const size_t startPos = hypo.GetCurrTargetWordsRange().GetStartPos();
|
|
||||||
|
|
||||||
// 1st n-gram
|
|
||||||
vector<const Word*> contextFactor(GetNGramOrder());
|
|
||||||
size_t index = 0;
|
|
||||||
for (int currPos = (int) startPos - (int) GetNGramOrder() + 1 ; currPos <= (int) startPos ; currPos++) {
|
|
||||||
if (currPos >= 0)
|
|
||||||
contextFactor[index++] = &hypo.GetWord(currPos);
|
|
||||||
else {
|
|
||||||
contextFactor[index++] = &m_implementation->GetSentenceStartArray();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
FFState *res = m_implementation->NewState(ps);
|
|
||||||
float lmScore = ps ? m_implementation->GetValueGivenState(contextFactor, *res).score : m_implementation->GetValueForgotState(contextFactor, *res).score;
|
|
||||||
|
|
||||||
// main loop
|
|
||||||
size_t endPos = std::min(startPos + GetNGramOrder() - 2
|
|
||||||
, currEndPos);
|
|
||||||
for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++) {
|
|
||||||
// shift all args down 1 place
|
|
||||||
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
|
|
||||||
contextFactor[i] = contextFactor[i + 1];
|
|
||||||
|
|
||||||
// add last factor
|
|
||||||
contextFactor.back() = &hypo.GetWord(currPos);
|
|
||||||
|
|
||||||
lmScore += m_implementation->GetValueGivenState(contextFactor, *res).score;
|
|
||||||
}
|
|
||||||
|
|
||||||
// end of sentence
|
|
||||||
if (hypo.IsSourceCompleted()) {
|
|
||||||
const size_t size = hypo.GetSize();
|
|
||||||
contextFactor.back() = &m_implementation->GetSentenceEndArray();
|
|
||||||
|
|
||||||
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i ++) {
|
|
||||||
int currPos = (int)(size - GetNGramOrder() + i + 1);
|
|
||||||
if (currPos < 0)
|
|
||||||
contextFactor[i] = &m_implementation->GetSentenceStartArray();
|
|
||||||
else
|
|
||||||
contextFactor[i] = &hypo.GetWord((size_t)currPos);
|
|
||||||
}
|
|
||||||
lmScore += m_implementation->GetValueForgotState(contextFactor, *res).score;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if (endPos < currEndPos) {
|
|
||||||
//need to get the LM state (otherwise the last LM state is fine)
|
|
||||||
for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
|
|
||||||
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
|
|
||||||
contextFactor[i] = contextFactor[i + 1];
|
|
||||||
contextFactor.back() = &hypo.GetWord(currPos);
|
|
||||||
}
|
|
||||||
m_implementation->GetState(contextFactor, *res);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (m_enableOOVFeature) {
|
|
||||||
vector<float> scores(2);
|
|
||||||
scores[0] = lmScore;
|
|
||||||
scores[1] = 0;
|
|
||||||
out->PlusEquals(this, scores);
|
|
||||||
} else {
|
|
||||||
out->PlusEquals(this, lmScore);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
IFVERBOSE(2) {
|
|
||||||
hypo.GetManager().GetSentenceStats().AddTimeCalcLM( clock()-t );
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -54,9 +54,6 @@ protected:
|
|||||||
#endif
|
#endif
|
||||||
bool m_enableOOVFeature;
|
bool m_enableOOVFeature;
|
||||||
|
|
||||||
|
|
||||||
void ShiftOrPush(std::vector<const Word*> &contextFactor, const Word &word) const;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* Create a new language model
|
* Create a new language model
|
||||||
@ -88,11 +85,9 @@ public:
|
|||||||
* \param ngramScore score of only n-gram of order m_nGramOrder
|
* \param ngramScore score of only n-gram of order m_nGramOrder
|
||||||
* \param oovCount number of LM OOVs
|
* \param oovCount number of LM OOVs
|
||||||
*/
|
*/
|
||||||
void CalcScore(
|
void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const {
|
||||||
const Phrase &phrase,
|
return m_implementation->CalcScore(phrase, fullScore, ngramScore, oovCount);
|
||||||
float &fullScore,
|
}
|
||||||
float &ngramScore,
|
|
||||||
size_t &oovCount) const;
|
|
||||||
|
|
||||||
//! max n-gram order of LM
|
//! max n-gram order of LM
|
||||||
size_t GetNGramOrder() const {
|
size_t GetNGramOrder() const {
|
||||||
@ -103,6 +98,10 @@ public:
|
|||||||
return m_implementation->GetScoreProducerDescription(idx);
|
return m_implementation->GetScoreProducerDescription(idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool OOVFeatureEnabled() const {
|
||||||
|
return m_enableOOVFeature;
|
||||||
|
}
|
||||||
|
|
||||||
float GetWeight() const;
|
float GetWeight() const;
|
||||||
float GetOOVWeight() const;
|
float GetOOVWeight() const;
|
||||||
|
|
||||||
@ -120,10 +119,12 @@ public:
|
|||||||
|
|
||||||
virtual const FFState* EmptyHypothesisState(const InputType &input) const;
|
virtual const FFState* EmptyHypothesisState(const InputType &input) const;
|
||||||
|
|
||||||
virtual FFState* Evaluate(
|
FFState* Evaluate(
|
||||||
const Hypothesis& cur_hypo,
|
const Hypothesis& cur_hypo,
|
||||||
const FFState* prev_state,
|
const FFState* prev_state,
|
||||||
ScoreComponentCollection* accumulator) const;
|
ScoreComponentCollection* accumulator) const {
|
||||||
|
return m_implementation->Evaluate(cur_hypo, prev_state, accumulator, this);
|
||||||
|
}
|
||||||
|
|
||||||
FFState* EvaluateChart(
|
FFState* EvaluateChart(
|
||||||
const ChartHypothesis& cur_hypo,
|
const ChartHypothesis& cur_hypo,
|
||||||
|
@ -68,6 +68,144 @@ void LanguageModelImplementation::GetState(
|
|||||||
GetValueForgotState(contextFactor, state);
|
GetValueForgotState(contextFactor, state);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Calculate score of a phrase.
|
||||||
|
void LanguageModelImplementation::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const {
|
||||||
|
fullScore = 0;
|
||||||
|
ngramScore = 0;
|
||||||
|
|
||||||
|
oovCount = 0;
|
||||||
|
|
||||||
|
size_t phraseSize = phrase.GetSize();
|
||||||
|
if (!phraseSize) return;
|
||||||
|
|
||||||
|
vector<const Word*> contextFactor;
|
||||||
|
contextFactor.reserve(GetNGramOrder());
|
||||||
|
std::auto_ptr<FFState> state(NewState((phrase.GetWord(0) == GetSentenceStartArray()) ?
|
||||||
|
GetBeginSentenceState() : GetNullContextState()));
|
||||||
|
size_t currPos = 0;
|
||||||
|
while (currPos < phraseSize) {
|
||||||
|
const Word &word = phrase.GetWord(currPos);
|
||||||
|
|
||||||
|
if (word.IsNonTerminal()) {
|
||||||
|
// do nothing. reset ngram. needed to score target phrases during pt loading in chart decoding
|
||||||
|
if (!contextFactor.empty()) {
|
||||||
|
// TODO: state operator= ?
|
||||||
|
state.reset(NewState(GetNullContextState()));
|
||||||
|
contextFactor.clear();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ShiftOrPush(contextFactor, word);
|
||||||
|
assert(contextFactor.size() <= GetNGramOrder());
|
||||||
|
|
||||||
|
if (word == GetSentenceStartArray()) {
|
||||||
|
// do nothing, don't include prob for <s> unigram
|
||||||
|
assert(currPos == 0);
|
||||||
|
} else {
|
||||||
|
LMResult result = GetValueGivenState(contextFactor, *state);
|
||||||
|
fullScore += result.score;
|
||||||
|
if (contextFactor.size() == GetNGramOrder())
|
||||||
|
ngramScore += result.score;
|
||||||
|
if (contextFactor.size() == 1 && result.unknown)
|
||||||
|
++oovCount;
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
currPos++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FFState *LanguageModelImplementation::Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out, const LanguageModel *feature) const {
|
||||||
|
// In this function, we only compute the LM scores of n-grams that overlap a
|
||||||
|
// phrase boundary. Phrase-internal scores are taken directly from the
|
||||||
|
// translation option.
|
||||||
|
|
||||||
|
// In the case of unigram language models, there is no overlap, so we don't
|
||||||
|
// need to do anything.
|
||||||
|
if(GetNGramOrder() <= 1)
|
||||||
|
return NULL;
|
||||||
|
|
||||||
|
clock_t t = 0;
|
||||||
|
IFVERBOSE(2) {
|
||||||
|
t = clock(); // track time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty phrase added? nothing to be done
|
||||||
|
if (hypo.GetCurrTargetLength() == 0)
|
||||||
|
return ps ? NewState(ps) : NULL;
|
||||||
|
|
||||||
|
const size_t currEndPos = hypo.GetCurrTargetWordsRange().GetEndPos();
|
||||||
|
const size_t startPos = hypo.GetCurrTargetWordsRange().GetStartPos();
|
||||||
|
|
||||||
|
// 1st n-gram
|
||||||
|
vector<const Word*> contextFactor(GetNGramOrder());
|
||||||
|
size_t index = 0;
|
||||||
|
for (int currPos = (int) startPos - (int) GetNGramOrder() + 1 ; currPos <= (int) startPos ; currPos++) {
|
||||||
|
if (currPos >= 0)
|
||||||
|
contextFactor[index++] = &hypo.GetWord(currPos);
|
||||||
|
else {
|
||||||
|
contextFactor[index++] = &GetSentenceStartArray();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
FFState *res = NewState(ps);
|
||||||
|
float lmScore = ps ? GetValueGivenState(contextFactor, *res).score : GetValueForgotState(contextFactor, *res).score;
|
||||||
|
|
||||||
|
// main loop
|
||||||
|
size_t endPos = std::min(startPos + GetNGramOrder() - 2
|
||||||
|
, currEndPos);
|
||||||
|
for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++) {
|
||||||
|
// shift all args down 1 place
|
||||||
|
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
|
||||||
|
contextFactor[i] = contextFactor[i + 1];
|
||||||
|
|
||||||
|
// add last factor
|
||||||
|
contextFactor.back() = &hypo.GetWord(currPos);
|
||||||
|
|
||||||
|
lmScore += GetValueGivenState(contextFactor, *res).score;
|
||||||
|
}
|
||||||
|
|
||||||
|
// end of sentence
|
||||||
|
if (hypo.IsSourceCompleted()) {
|
||||||
|
const size_t size = hypo.GetSize();
|
||||||
|
contextFactor.back() = &GetSentenceEndArray();
|
||||||
|
|
||||||
|
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i ++) {
|
||||||
|
int currPos = (int)(size - GetNGramOrder() + i + 1);
|
||||||
|
if (currPos < 0)
|
||||||
|
contextFactor[i] = &GetSentenceStartArray();
|
||||||
|
else
|
||||||
|
contextFactor[i] = &hypo.GetWord((size_t)currPos);
|
||||||
|
}
|
||||||
|
lmScore += GetValueForgotState(contextFactor, *res).score;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (endPos < currEndPos) {
|
||||||
|
//need to get the LM state (otherwise the last LM state is fine)
|
||||||
|
for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
|
||||||
|
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
|
||||||
|
contextFactor[i] = contextFactor[i + 1];
|
||||||
|
contextFactor.back() = &hypo.GetWord(currPos);
|
||||||
|
}
|
||||||
|
GetState(contextFactor, *res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (feature->OOVFeatureEnabled()) {
|
||||||
|
vector<float> scores(2);
|
||||||
|
scores[0] = lmScore;
|
||||||
|
scores[1] = 0;
|
||||||
|
out->PlusEquals(feature, scores);
|
||||||
|
} else {
|
||||||
|
out->PlusEquals(feature, lmScore);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
IFVERBOSE(2) {
|
||||||
|
hypo.GetManager().GetSentenceStats().AddTimeCalcLM( clock()-t );
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// This is the FFState used by LanguageModelImplementation::EvaluateChart.
|
// This is the FFState used by LanguageModelImplementation::EvaluateChart.
|
||||||
@ -364,7 +502,7 @@ FFState* LanguageModelImplementation::EvaluateChart(const ChartHypothesis& hypo,
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
void LanguageModelImplementation::updateChartScore( float *prefixScore, float *finalizedScore, float score, size_t wordPos ) const {
|
void LanguageModelImplementation::updateChartScore(float *prefixScore, float *finalizedScore, float score, size_t wordPos) const {
|
||||||
if (wordPos < GetNGramOrder()) {
|
if (wordPos < GetNGramOrder()) {
|
||||||
*prefixScore += score;
|
*prefixScore += score;
|
||||||
}
|
}
|
||||||
|
@ -99,9 +99,13 @@ public:
|
|||||||
virtual const FFState *GetBeginSentenceState() const = 0;
|
virtual const FFState *GetBeginSentenceState() const = 0;
|
||||||
virtual FFState *NewState(const FFState *from = NULL) const = 0;
|
virtual FFState *NewState(const FFState *from = NULL) const = 0;
|
||||||
|
|
||||||
|
void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const;
|
||||||
|
|
||||||
|
FFState *Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out, const LanguageModel *feature) const;
|
||||||
|
|
||||||
virtual FFState* EvaluateChart(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection* accumulator, const LanguageModel *feature) const;
|
virtual FFState* EvaluateChart(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection* accumulator, const LanguageModel *feature) const;
|
||||||
|
|
||||||
void updateChartScore( float *prefixScore, float *finalScore, float score, size_t wordPos ) const;
|
void updateChartScore(float *prefixScore, float *finalScore, float score, size_t wordPos) const;
|
||||||
|
|
||||||
//! max n-gram order of LM
|
//! max n-gram order of LM
|
||||||
size_t GetNGramOrder() const {
|
size_t GetNGramOrder() const {
|
||||||
|
Loading…
Reference in New Issue
Block a user