Move phrase scoring from LanguageModel to LanguageModelImplementation.

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4324 1f5c12ca-751b-0410-a591-d2e778427230
Author: heafield
Date:   2011-10-11 13:50:44 +00:00
Parent: c9995dc44c
Commit: 16e37adbe0

4 changed files with 155 additions and 170 deletions
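At its core this is a delegation refactor: LanguageModel keeps its public interface, but CalcScore and Evaluate now forward to the LanguageModelImplementation it wraps, so every concrete LM backend shares one copy of the scoring logic. A minimal sketch of the pattern, assuming hypothetical Facade/Impl stand-ins rather than the real Moses classes:

// Illustrative sketch only; Facade/Impl are hypothetical stand-ins for
// LanguageModel and LanguageModelImplementation, and the dummy score is
// not a real LM probability.
#include <iostream>
#include <memory>
#include <string>

struct Impl {
  // The implementation owns the actual scoring logic.
  float CalcScore(const std::string &phrase) const {
    return -0.5f * static_cast<float>(phrase.size());
  }
};

struct Facade {
  explicit Facade(std::shared_ptr<Impl> impl) : m_impl(std::move(impl)) {}
  // The facade body shrinks to a one-line forward, mirroring
  // LanguageModel::CalcScore after this commit.
  float CalcScore(const std::string &phrase) const {
    return m_impl->CalcScore(phrase);
  }
private:
  std::shared_ptr<Impl> m_impl;
};

int main() {
  Facade lm(std::make_shared<Impl>());
  std::cout << lm.CalcScore("hello world") << "\n";  // prints -5.5
}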

LanguageModel.cpp

@@ -96,168 +96,10 @@ float LanguageModel::GetOOVWeight() const
}
void LanguageModel::CalcScore(const Phrase &phrase
, float &fullScore
, float &ngramScore
, size_t &oovCount) const
{
fullScore = 0;
ngramScore = 0;
oovCount = 0;
size_t phraseSize = phrase.GetSize();
if (!phraseSize) return;
vector<const Word*> contextFactor;
contextFactor.reserve(GetNGramOrder());
std::auto_ptr<FFState> state(m_implementation->NewState((phrase.GetWord(0) == m_implementation->GetSentenceStartArray()) ?
m_implementation->GetBeginSentenceState() : m_implementation->GetNullContextState()));
size_t currPos = 0;
while (currPos < phraseSize) {
const Word &word = phrase.GetWord(currPos);
if (word.IsNonTerminal()) {
// do nothing. reset ngram. needed to score target phrases during pt loading in chart decoding
if (!contextFactor.empty()) {
// TODO: state operator= ?
state.reset(m_implementation->NewState(m_implementation->GetNullContextState()));
contextFactor.clear();
}
} else {
ShiftOrPush(contextFactor, word);
assert(contextFactor.size() <= GetNGramOrder());
if (word == m_implementation->GetSentenceStartArray()) {
// do nothing, don't include prob for <s> unigram
assert(currPos == 0);
} else {
LMResult result = m_implementation->GetValueGivenState(contextFactor, *state);
fullScore += result.score;
if (contextFactor.size() == GetNGramOrder())
ngramScore += result.score;
if (contextFactor.size() == 1 && result.unknown)
++oovCount;
}
}
currPos++;
}
}
void LanguageModel::ShiftOrPush(vector<const Word*> &contextFactor, const Word &word) const
{
if (contextFactor.size() < GetNGramOrder()) {
contextFactor.push_back(&word);
} else {
// shift
for (size_t currNGramOrder = 0 ; currNGramOrder < GetNGramOrder() - 1 ; currNGramOrder++) {
contextFactor[currNGramOrder] = contextFactor[currNGramOrder + 1];
}
contextFactor[GetNGramOrder() - 1] = &word;
}
}
const FFState* LanguageModel::EmptyHypothesisState(const InputType &/*input*/) const
{
// This is actually correct. The empty _hypothesis_ has <s> in it. Phrases use m_emptyContextState.
return m_implementation->NewState(m_implementation->GetBeginSentenceState());
}
FFState* LanguageModel::Evaluate(
const Hypothesis& hypo,
const FFState* ps,
ScoreComponentCollection* out) const
{
// In this function, we only compute the LM scores of n-grams that overlap a
// phrase boundary. Phrase-internal scores are taken directly from the
// translation option.
// In the case of unigram language models, there is no overlap, so we don't
// need to do anything.
if(GetNGramOrder() <= 1)
return NULL;
clock_t t = 0;
IFVERBOSE(2) {
t = clock(); // track time
}
// Empty phrase added? nothing to be done
if (hypo.GetCurrTargetLength() == 0)
return ps ? m_implementation->NewState(ps) : NULL;
const size_t currEndPos = hypo.GetCurrTargetWordsRange().GetEndPos();
const size_t startPos = hypo.GetCurrTargetWordsRange().GetStartPos();
// 1st n-gram
vector<const Word*> contextFactor(GetNGramOrder());
size_t index = 0;
for (int currPos = (int) startPos - (int) GetNGramOrder() + 1 ; currPos <= (int) startPos ; currPos++) {
if (currPos >= 0)
contextFactor[index++] = &hypo.GetWord(currPos);
else {
contextFactor[index++] = &m_implementation->GetSentenceStartArray();
}
}
FFState *res = m_implementation->NewState(ps);
float lmScore = ps ? m_implementation->GetValueGivenState(contextFactor, *res).score : m_implementation->GetValueForgotState(contextFactor, *res).score;
// main loop
size_t endPos = std::min(startPos + GetNGramOrder() - 2
, currEndPos);
for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++) {
// shift all args down 1 place
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
contextFactor[i] = contextFactor[i + 1];
// add last factor
contextFactor.back() = &hypo.GetWord(currPos);
lmScore += m_implementation->GetValueGivenState(contextFactor, *res).score;
}
// end of sentence
if (hypo.IsSourceCompleted()) {
const size_t size = hypo.GetSize();
contextFactor.back() = &m_implementation->GetSentenceEndArray();
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i ++) {
int currPos = (int)(size - GetNGramOrder() + i + 1);
if (currPos < 0)
contextFactor[i] = &m_implementation->GetSentenceStartArray();
else
contextFactor[i] = &hypo.GetWord((size_t)currPos);
}
lmScore += m_implementation->GetValueForgotState(contextFactor, *res).score;
}
else
{
if (endPos < currEndPos) {
//need to get the LM state (otherwise the last LM state is fine)
for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
contextFactor[i] = contextFactor[i + 1];
contextFactor.back() = &hypo.GetWord(currPos);
}
m_implementation->GetState(contextFactor, *res);
}
}
if (m_enableOOVFeature) {
vector<float> scores(2);
scores[0] = lmScore;
scores[1] = 0;
out->PlusEquals(this, scores);
} else {
out->PlusEquals(this, lmScore);
}
IFVERBOSE(2) {
hypo.GetManager().GetSentenceStats().AddTimeCalcLM( clock()-t );
}
return res;
}
}
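ShiftOrPush above maintains a sliding window of at most n context words: push while the window is still filling, then shift everything left one slot and overwrite the last entry. A standalone sketch of the same behavior on plain strings (window size fixed at 3 for illustration; std::string stands in for the Moses Word type):

#include <iostream>
#include <string>
#include <vector>

// Sketch of the ShiftOrPush sliding window, simplified to std::string.
void ShiftOrPush(std::vector<std::string> &context,
                 const std::string &word, size_t order) {
  if (context.size() < order) {
    context.push_back(word);           // still filling the window
  } else {
    for (size_t i = 0; i + 1 < order; ++i)
      context[i] = context[i + 1];     // shift left one slot
    context[order - 1] = word;         // overwrite the last entry
  }
}

int main() {
  const std::string words[] = {"the", "cat", "sat", "on", "the", "mat"};
  std::vector<std::string> ctx;
  for (const auto &w : words) {
    ShiftOrPush(ctx, w, 3);
    for (const auto &c : ctx) std::cout << c << ' ';
    std::cout << '\n';                 // window never exceeds 3 words
  }
}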

LanguageModel.h

@@ -54,9 +54,6 @@ protected:
#endif
bool m_enableOOVFeature;
-void ShiftOrPush(std::vector<const Word*> &contextFactor, const Word &word) const;
public:
/**
* Create a new language model
@@ -88,11 +85,9 @@ public:
* \param ngramScore score of only n-gram of order m_nGramOrder
* \param oovCount number of LM OOVs
*/
-void CalcScore(
-  const Phrase &phrase,
-  float &fullScore,
-  float &ngramScore,
-  size_t &oovCount) const;
+void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const {
+  return m_implementation->CalcScore(phrase, fullScore, ngramScore, oovCount);
+}
//! max n-gram order of LM
size_t GetNGramOrder() const {
@@ -103,6 +98,10 @@ public:
return m_implementation->GetScoreProducerDescription(idx);
}
+bool OOVFeatureEnabled() const {
+  return m_enableOOVFeature;
+}
float GetWeight() const;
float GetOOVWeight() const;
@@ -120,10 +119,12 @@ public:
virtual const FFState* EmptyHypothesisState(const InputType &input) const;
-virtual FFState* Evaluate(
+FFState* Evaluate(
const Hypothesis& cur_hypo,
const FFState* prev_state,
-  ScoreComponentCollection* accumulator) const;
+  ScoreComponentCollection* accumulator) const {
+  return m_implementation->Evaluate(cur_hypo, prev_state, accumulator, this);
+}
FFState* EvaluateChart(
const ChartHypothesis& cur_hypo,

LanguageModelImplementation.cpp

@@ -68,6 +68,144 @@ void LanguageModelImplementation::GetState(
GetValueForgotState(contextFactor, state);
}
// Calculate score of a phrase.
void LanguageModelImplementation::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const {
fullScore = 0;
ngramScore = 0;
oovCount = 0;
size_t phraseSize = phrase.GetSize();
if (!phraseSize) return;
vector<const Word*> contextFactor;
contextFactor.reserve(GetNGramOrder());
std::auto_ptr<FFState> state(NewState((phrase.GetWord(0) == GetSentenceStartArray()) ?
GetBeginSentenceState() : GetNullContextState()));
size_t currPos = 0;
while (currPos < phraseSize) {
const Word &word = phrase.GetWord(currPos);
if (word.IsNonTerminal()) {
// do nothing. reset ngram. needed to score target phrases during pt loading in chart decoding
if (!contextFactor.empty()) {
// TODO: state operator= ?
state.reset(NewState(GetNullContextState()));
contextFactor.clear();
}
} else {
ShiftOrPush(contextFactor, word);
assert(contextFactor.size() <= GetNGramOrder());
if (word == GetSentenceStartArray()) {
// do nothing, don't include prob for <s> unigram
assert(currPos == 0);
} else {
LMResult result = GetValueGivenState(contextFactor, *state);
fullScore += result.score;
if (contextFactor.size() == GetNGramOrder())
ngramScore += result.score;
if (contextFactor.size() == 1 && result.unknown)
++oovCount;
}
}
currPos++;
}
}
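A note on the three outputs of CalcScore: fullScore accumulates every scored n-gram, including the short-context ones at the start of the phrase, while ngramScore only counts full-order n-grams, because the shorter ones may be rescored with real left context during search; oovCount counts unknown unigrams. A worked sketch with a fake trigram scorer (FakeScore is a hypothetical stand-in, not a real LM lookup):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for an LM lookup: every n-gram costs -1.0.
float FakeScore(const std::vector<std::string> &ngram) {
  (void)ngram;
  return -1.0f;
}

int main() {
  const size_t order = 3;
  const std::vector<std::string> phrase = {"a", "b", "c", "d"};
  std::vector<std::string> context;
  float fullScore = 0, ngramScore = 0;
  for (const auto &w : phrase) {
    if (context.size() < order) {
      context.push_back(w);            // window still filling
    } else {                           // shift, as in ShiftOrPush
      for (size_t i = 0; i + 1 < order; ++i) context[i] = context[i + 1];
      context.back() = w;
    }
    const float s = FakeScore(context);
    fullScore += s;                               // every n-gram counts
    if (context.size() == order) ngramScore += s; // full-order n-grams only
  }
  // 4 words, order 3: fullScore sees 4 n-grams (-4); ngramScore sees
  // only "a b c" and "b c d" (-2).
  std::cout << fullScore << ' ' << ngramScore << '\n';
}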
FFState *LanguageModelImplementation::Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out, const LanguageModel *feature) const {
// In this function, we only compute the LM scores of n-grams that overlap a
// phrase boundary. Phrase-internal scores are taken directly from the
// translation option.
// In the case of unigram language models, there is no overlap, so we don't
// need to do anything.
if(GetNGramOrder() <= 1)
return NULL;
clock_t t = 0;
IFVERBOSE(2) {
t = clock(); // track time
}
// Empty phrase added? nothing to be done
if (hypo.GetCurrTargetLength() == 0)
return ps ? NewState(ps) : NULL;
const size_t currEndPos = hypo.GetCurrTargetWordsRange().GetEndPos();
const size_t startPos = hypo.GetCurrTargetWordsRange().GetStartPos();
// 1st n-gram
vector<const Word*> contextFactor(GetNGramOrder());
size_t index = 0;
for (int currPos = (int) startPos - (int) GetNGramOrder() + 1 ; currPos <= (int) startPos ; currPos++) {
if (currPos >= 0)
contextFactor[index++] = &hypo.GetWord(currPos);
else {
contextFactor[index++] = &GetSentenceStartArray();
}
}
FFState *res = NewState(ps);
float lmScore = ps ? GetValueGivenState(contextFactor, *res).score : GetValueForgotState(contextFactor, *res).score;
// main loop
size_t endPos = std::min(startPos + GetNGramOrder() - 2
, currEndPos);
for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++) {
// shift all args down 1 place
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
contextFactor[i] = contextFactor[i + 1];
// add last factor
contextFactor.back() = &hypo.GetWord(currPos);
lmScore += GetValueGivenState(contextFactor, *res).score;
}
// end of sentence
if (hypo.IsSourceCompleted()) {
const size_t size = hypo.GetSize();
contextFactor.back() = &GetSentenceEndArray();
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i ++) {
int currPos = (int)(size - GetNGramOrder() + i + 1);
if (currPos < 0)
contextFactor[i] = &GetSentenceStartArray();
else
contextFactor[i] = &hypo.GetWord((size_t)currPos);
}
lmScore += GetValueForgotState(contextFactor, *res).score;
}
else
{
if (endPos < currEndPos) {
//need to get the LM state (otherwise the last LM state is fine)
for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
contextFactor[i] = contextFactor[i + 1];
contextFactor.back() = &hypo.GetWord(currPos);
}
GetState(contextFactor, *res);
}
}
if (feature->OOVFeatureEnabled()) {
vector<float> scores(2);
scores[0] = lmScore;
scores[1] = 0;
out->PlusEquals(feature, scores);
} else {
out->PlusEquals(feature, lmScore);
}
IFVERBOSE(2) {
hypo.GetManager().GetSentenceStats().AddTimeCalcLM( clock()-t );
}
return res;
}
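The arithmetic in the first loop above deserves a worked example: only n-grams that end inside the new phrase but reach back across its left boundary are scored here; anything ending after startPos + order - 2 lies entirely within the phrase and was already covered by CalcScore. A sketch that just enumerates those boundary n-grams (strings stand in for the Word/Hypothesis types; no LM state or scores are involved):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  const size_t order = 3;
  // Hypothetical target side so far; the new phrase covers positions 2..4.
  const std::vector<std::string> target = {"we", "saw", "the", "black", "cat"};
  const size_t startPos = 2, currEndPos = 4;

  // Boundary n-grams end no later than startPos + order - 2.
  const size_t endPos = std::min(startPos + order - 2, currEndPos);
  for (size_t pos = startPos; pos <= endPos; ++pos) {
    // Print the n-gram ending at pos; the real code pads positions
    // before the sentence start with <s>, which this example never needs.
    for (size_t i = (pos + 1 >= order) ? pos + 1 - order : 0; i <= pos; ++i)
      std::cout << target[i] << ' ';
    std::cout << '\n';  // prints "we saw the" then "saw the black"
  }
  // "the black cat" is phrase-internal and is not rescored here.
}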
namespace {
// This is the FFState used by LanguageModelImplementation::EvaluateChart.
@@ -364,7 +502,7 @@ FFState* LanguageModelImplementation::EvaluateChart(const ChartHypothesis& hypo,
return ret;
}
-void LanguageModelImplementation::updateChartScore( float *prefixScore, float *finalizedScore, float score, size_t wordPos ) const {
+void LanguageModelImplementation::updateChartScore(float *prefixScore, float *finalizedScore, float score, size_t wordPos) const {
if (wordPos < GetNGramOrder()) {
*prefixScore += score;
}
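updateChartScore splits a chart hypothesis's LM mass by position: scores for the first words, whose left context is still unknown, go into a provisional prefix total that can be revised once the constituent is embedded in a larger one; later words have full context, so their scores are final. A minimal sketch, assuming a finalizedScore branch for the part of the function the hunk above truncates:

#include <cstddef>
#include <iostream>

// Sketch of the prefix/finalized split; the else branch is an assumed
// completion of the truncated function above, and the -1.0 scores are
// dummies, not real LM probabilities.
void updateChartScore(float *prefixScore, float *finalizedScore,
                      float score, size_t wordPos, size_t order) {
  if (wordPos < order) {
    *prefixScore += score;       // provisional: left context still unknown
  } else {
    *finalizedScore += score;    // full context seen: score is final
  }
}

int main() {
  const size_t order = 3;
  float prefix = 0, finalized = 0;
  for (size_t pos = 0; pos < 5; ++pos)
    updateChartScore(&prefix, &finalized, -1.0f, pos, order);
  std::cout << prefix << ' ' << finalized << '\n';  // prints -3 -2
}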

LanguageModelImplementation.h

@@ -99,9 +99,13 @@ public:
virtual const FFState *GetBeginSentenceState() const = 0;
virtual FFState *NewState(const FFState *from = NULL) const = 0;
+void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const;
+FFState *Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out, const LanguageModel *feature) const;
virtual FFState* EvaluateChart(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection* accumulator, const LanguageModel *feature) const;
-void updateChartScore( float *prefixScore, float *finalScore, float score, size_t wordPos ) const;
+void updateChartScore(float *prefixScore, float *finalScore, float score, size_t wordPos) const;
//! max n-gram order of LM
size_t GetNGramOrder() const {