mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-09-19 07:07:24 +03:00
Move phrase scoring from LanguageModel to LanguageModelImplementation.
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4324 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
parent
c9995dc44c
commit
16e37adbe0
@ -96,168 +96,10 @@ float LanguageModel::GetOOVWeight() const
|
||||
|
||||
}
|
||||
|
||||
void LanguageModel::CalcScore(const Phrase &phrase
|
||||
, float &fullScore
|
||||
, float &ngramScore
|
||||
, size_t &oovCount) const
|
||||
{
|
||||
fullScore = 0;
|
||||
ngramScore = 0;
|
||||
|
||||
oovCount = 0;
|
||||
|
||||
size_t phraseSize = phrase.GetSize();
|
||||
if (!phraseSize) return;
|
||||
|
||||
vector<const Word*> contextFactor;
|
||||
contextFactor.reserve(GetNGramOrder());
|
||||
std::auto_ptr<FFState> state(m_implementation->NewState((phrase.GetWord(0) == m_implementation->GetSentenceStartArray()) ?
|
||||
m_implementation->GetBeginSentenceState() : m_implementation->GetNullContextState()));
|
||||
size_t currPos = 0;
|
||||
while (currPos < phraseSize) {
|
||||
const Word &word = phrase.GetWord(currPos);
|
||||
|
||||
if (word.IsNonTerminal()) {
|
||||
// do nothing. reset ngram. needed to score targbet phrases during pt loading in chart decoding
|
||||
if (!contextFactor.empty()) {
|
||||
// TODO: state operator= ?
|
||||
state.reset(m_implementation->NewState(m_implementation->GetNullContextState()));
|
||||
contextFactor.clear();
|
||||
}
|
||||
} else {
|
||||
ShiftOrPush(contextFactor, word);
|
||||
assert(contextFactor.size() <= GetNGramOrder());
|
||||
|
||||
if (word == m_implementation->GetSentenceStartArray()) {
|
||||
// do nothing, don't include prob for <s> unigram
|
||||
assert(currPos == 0);
|
||||
} else {
|
||||
LMResult result = m_implementation->GetValueGivenState(contextFactor, *state);
|
||||
fullScore += result.score;
|
||||
if (contextFactor.size() == GetNGramOrder())
|
||||
ngramScore += result.score;
|
||||
if (contextFactor.size() == 1 && result.unknown)
|
||||
++oovCount;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
currPos++;
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Append a word to the context window.
 * While the window holds fewer than GetNGramOrder() entries the word is pushed
 * at the end; otherwise the first GetNGramOrder() entries are moved one slot
 * towards the front and the word is stored in the last of those slots.
 */
void LanguageModel::ShiftOrPush(vector<const Word*> &contextFactor, const Word &word) const
{
  const size_t order = GetNGramOrder();

  // Still room in the window: just grow it.
  if (contextFactor.size() < order) {
    contextFactor.push_back(&word);
    return;
  }

  // Window is full: drop the oldest entry by sliding everything left.
  for (size_t src = 1; src < order; ++src) {
    contextFactor[src - 1] = contextFactor[src];
  }
  contextFactor[order - 1] = &word;
}
|
||||
|
||||
/**
 * State attached to the empty hypothesis.
 * The input argument is unused.
 */
const FFState* LanguageModel::EmptyHypothesisState(const InputType &/*input*/) const
{
  // Intentional: the empty _hypothesis_ already contains <s>, so it starts from
  // the begin-of-sentence state. Phrases use m_emptyContextState instead.
  const FFState *beginOfSentence = m_implementation->GetBeginSentenceState();
  return m_implementation->NewState(beginOfSentence);
}
|
||||
|
||||
/**
 * Add the LM score of the n-grams that straddle the boundary of the phrase
 * most recently appended to a hypothesis. Phrase-internal scores are taken
 * directly from the translation option and are not recomputed here.
 *
 * \param hypo  hypothesis whose current target range is being scored
 * \param ps    LM state of the previous hypothesis, may be NULL
 * \param out   score collection that receives the boundary score
 * \return NULL for models of order <= 1 and for an empty phrase with no
 *         previous state; otherwise a state obtained from NewState(ps)
 */
FFState* LanguageModel::Evaluate(
  const Hypothesis& hypo,
  const FFState* ps,
  ScoreComponentCollection* out) const
{
  // A unigram model has no n-gram overlapping a phrase boundary, so there is
  // nothing to do.
  const size_t order = GetNGramOrder();
  if (order <= 1)
    return NULL;

  clock_t t = 0;
  IFVERBOSE(2) {
    t = clock(); // track time
  }

  // An empty phrase changes nothing: hand back a copy of the previous state.
  if (hypo.GetCurrTargetLength() == 0) {
    if (!ps)
      return NULL;
    return m_implementation->NewState(ps);
  }

  const size_t currEndPos = hypo.GetCurrTargetWordsRange().GetEndPos();
  const size_t startPos = hypo.GetCurrTargetWordsRange().GetStartPos();

  // First n-gram: the window ends on the first word of the new phrase and is
  // padded with <s> wherever it would reach before the start of the sentence.
  std::vector<const Word*> window(order);
  for (size_t slot = 0; slot < order; ++slot) {
    const int pos = (int) startPos - (int) order + 1 + (int) slot;
    if (pos >= 0)
      window[slot] = &hypo.GetWord(pos);
    else
      window[slot] = &m_implementation->GetSentenceStartArray();
  }

  FFState *res = m_implementation->NewState(ps);
  float lmScore;
  if (ps)
    lmScore = m_implementation->GetValueGivenState(window, *res).score;
  else
    lmScore = m_implementation->GetValueForgotState(window, *res).score;

  // Main loop: remaining words of the phrase whose n-gram still overlaps the
  // previous context.
  const size_t endPos = std::min(startPos + order - 2, currEndPos);
  for (size_t pos = startPos + 1; pos <= endPos; ++pos) {
    // slide the window one word to the left
    for (size_t k = 1; k < order; ++k)
      window[k - 1] = window[k];
    // append the new last word
    window.back() = &hypo.GetWord(pos);

    lmScore += m_implementation->GetValueGivenState(window, *res).score;
  }

  if (hypo.IsSourceCompleted()) {
    // End of sentence: score </s> given the last (order - 1) words.
    const size_t size = hypo.GetSize();
    window.back() = &m_implementation->GetSentenceEndArray();

    for (size_t k = 0; k + 1 < order; ++k) {
      // May be negative for short sentences; those slots are filled with <s>.
      const int pos = (int)(size - order + k + 1);
      if (pos < 0)
        window[k] = &m_implementation->GetSentenceStartArray();
      else
        window[k] = &hypo.GetWord((size_t)pos);
    }
    lmScore += m_implementation->GetValueForgotState(window, *res).score;
  } else if (endPos < currEndPos) {
    // The phrase is longer than the scored part, so the LM state has to be
    // recomputed from its final words (otherwise the last LM state is fine).
    for (size_t pos = endPos + 1; pos <= currEndPos; ++pos) {
      for (size_t k = 1; k < order; ++k)
        window[k - 1] = window[k];
      window.back() = &hypo.GetWord(pos);
    }
    m_implementation->GetState(window, *res);
  }

  if (m_enableOOVFeature) {
    // Two components: the LM score and a zero for the OOV slot.
    std::vector<float> scores(2);
    scores[1] = 0;
    scores[0] = lmScore;
    out->PlusEquals(this, scores);
  } else {
    out->PlusEquals(this, lmScore);
  }

  IFVERBOSE(2) {
    hypo.GetManager().GetSentenceStats().AddTimeCalcLM( clock()-t );
  }
  return res;
}
|
||||
|
||||
}
|
||||
|
@ -54,9 +54,6 @@ protected:
|
||||
#endif
|
||||
bool m_enableOOVFeature;
|
||||
|
||||
|
||||
void ShiftOrPush(std::vector<const Word*> &contextFactor, const Word &word) const;
|
||||
|
||||
public:
|
||||
/**
|
||||
* Create a new language model
|
||||
@ -88,11 +85,9 @@ public:
|
||||
* \param ngramScore score of only n-gram of order m_nGramOrder
|
||||
* \param oovCount number of LM OOVs
|
||||
*/
|
||||
void CalcScore(
|
||||
const Phrase &phrase,
|
||||
float &fullScore,
|
||||
float &ngramScore,
|
||||
size_t &oovCount) const;
|
||||
void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const {
|
||||
return m_implementation->CalcScore(phrase, fullScore, ngramScore, oovCount);
|
||||
}
|
||||
|
||||
//! max n-gram order of LM
|
||||
size_t GetNGramOrder() const {
|
||||
@ -103,6 +98,10 @@ public:
|
||||
return m_implementation->GetScoreProducerDescription(idx);
|
||||
}
|
||||
|
||||
bool OOVFeatureEnabled() const {
|
||||
return m_enableOOVFeature;
|
||||
}
|
||||
|
||||
float GetWeight() const;
|
||||
float GetOOVWeight() const;
|
||||
|
||||
@ -120,10 +119,12 @@ public:
|
||||
|
||||
virtual const FFState* EmptyHypothesisState(const InputType &input) const;
|
||||
|
||||
virtual FFState* Evaluate(
|
||||
FFState* Evaluate(
|
||||
const Hypothesis& cur_hypo,
|
||||
const FFState* prev_state,
|
||||
ScoreComponentCollection* accumulator) const;
|
||||
ScoreComponentCollection* accumulator) const {
|
||||
return m_implementation->Evaluate(cur_hypo, prev_state, accumulator, this);
|
||||
}
|
||||
|
||||
FFState* EvaluateChart(
|
||||
const ChartHypothesis& cur_hypo,
|
||||
|
@ -68,6 +68,144 @@ void LanguageModelImplementation::GetState(
|
||||
GetValueForgotState(contextFactor, state);
|
||||
}
|
||||
|
||||
// Calculate score of a phrase.
|
||||
void LanguageModelImplementation::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const {
|
||||
fullScore = 0;
|
||||
ngramScore = 0;
|
||||
|
||||
oovCount = 0;
|
||||
|
||||
size_t phraseSize = phrase.GetSize();
|
||||
if (!phraseSize) return;
|
||||
|
||||
vector<const Word*> contextFactor;
|
||||
contextFactor.reserve(GetNGramOrder());
|
||||
std::auto_ptr<FFState> state(NewState((phrase.GetWord(0) == GetSentenceStartArray()) ?
|
||||
GetBeginSentenceState() : GetNullContextState()));
|
||||
size_t currPos = 0;
|
||||
while (currPos < phraseSize) {
|
||||
const Word &word = phrase.GetWord(currPos);
|
||||
|
||||
if (word.IsNonTerminal()) {
|
||||
// do nothing. reset ngram. needed to score target phrases during pt loading in chart decoding
|
||||
if (!contextFactor.empty()) {
|
||||
// TODO: state operator= ?
|
||||
state.reset(NewState(GetNullContextState()));
|
||||
contextFactor.clear();
|
||||
}
|
||||
} else {
|
||||
ShiftOrPush(contextFactor, word);
|
||||
assert(contextFactor.size() <= GetNGramOrder());
|
||||
|
||||
if (word == GetSentenceStartArray()) {
|
||||
// do nothing, don't include prob for <s> unigram
|
||||
assert(currPos == 0);
|
||||
} else {
|
||||
LMResult result = GetValueGivenState(contextFactor, *state);
|
||||
fullScore += result.score;
|
||||
if (contextFactor.size() == GetNGramOrder())
|
||||
ngramScore += result.score;
|
||||
if (contextFactor.size() == 1 && result.unknown)
|
||||
++oovCount;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
currPos++;
|
||||
}
|
||||
}
|
||||
|
||||
// Add the LM score of the n-grams that straddle the boundary of the phrase most
// recently appended to a hypothesis. Phrase-internal scores are taken directly
// from the translation option and are not recomputed here.
//   hypo    - hypothesis whose current target range is being scored
//   ps      - LM state of the previous hypothesis, may be NULL
//   out     - score collection that receives the boundary score
//   feature - LanguageModel feature the score is added under; also consulted
//             for whether the OOV feature is enabled
// Returns NULL for models of order <= 1 and for an empty phrase with no previous
// state; otherwise a state obtained from NewState(ps).
FFState *LanguageModelImplementation::Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out, const LanguageModel *feature) const {
  // A unigram model has no n-gram overlapping a phrase boundary, so there is
  // nothing to do.
  const size_t order = GetNGramOrder();
  if (order <= 1)
    return NULL;

  clock_t t = 0;
  IFVERBOSE(2) {
    t = clock(); // track time
  }

  // An empty phrase changes nothing: hand back a copy of the previous state.
  if (hypo.GetCurrTargetLength() == 0) {
    if (!ps)
      return NULL;
    return NewState(ps);
  }

  const size_t currEndPos = hypo.GetCurrTargetWordsRange().GetEndPos();
  const size_t startPos = hypo.GetCurrTargetWordsRange().GetStartPos();

  // First n-gram: the window ends on the first word of the new phrase and is
  // padded with <s> wherever it would reach before the start of the sentence.
  std::vector<const Word*> window(order);
  for (size_t slot = 0; slot < order; ++slot) {
    const int pos = (int) startPos - (int) order + 1 + (int) slot;
    if (pos >= 0)
      window[slot] = &hypo.GetWord(pos);
    else
      window[slot] = &GetSentenceStartArray();
  }

  FFState *res = NewState(ps);
  float lmScore;
  if (ps)
    lmScore = GetValueGivenState(window, *res).score;
  else
    lmScore = GetValueForgotState(window, *res).score;

  // Main loop: remaining words of the phrase whose n-gram still overlaps the
  // previous context.
  const size_t endPos = std::min(startPos + order - 2, currEndPos);
  for (size_t pos = startPos + 1; pos <= endPos; ++pos) {
    // slide the window one word to the left
    for (size_t k = 1; k < order; ++k)
      window[k - 1] = window[k];
    // append the new last word
    window.back() = &hypo.GetWord(pos);

    lmScore += GetValueGivenState(window, *res).score;
  }

  if (hypo.IsSourceCompleted()) {
    // End of sentence: score </s> given the last (order - 1) words.
    const size_t size = hypo.GetSize();
    window.back() = &GetSentenceEndArray();

    for (size_t k = 0; k + 1 < order; ++k) {
      // May be negative for short sentences; those slots are filled with <s>.
      const int pos = (int)(size - order + k + 1);
      if (pos < 0)
        window[k] = &GetSentenceStartArray();
      else
        window[k] = &hypo.GetWord((size_t)pos);
    }
    lmScore += GetValueForgotState(window, *res).score;
  } else if (endPos < currEndPos) {
    // The phrase is longer than the scored part, so the LM state has to be
    // recomputed from its final words (otherwise the last LM state is fine).
    for (size_t pos = endPos + 1; pos <= currEndPos; ++pos) {
      for (size_t k = 1; k < order; ++k)
        window[k - 1] = window[k];
      window.back() = &hypo.GetWord(pos);
    }
    GetState(window, *res);
  }

  if (feature->OOVFeatureEnabled()) {
    // Two components: the LM score and a zero for the OOV slot.
    std::vector<float> scores(2);
    scores[1] = 0;
    scores[0] = lmScore;
    out->PlusEquals(feature, scores);
  } else {
    out->PlusEquals(feature, lmScore);
  }

  IFVERBOSE(2) {
    hypo.GetManager().GetSentenceStats().AddTimeCalcLM( clock()-t );
  }
  return res;
}
|
||||
|
||||
namespace {
|
||||
|
||||
// This is the FFState used by LanguageModelImplementation::EvaluateChart.
|
||||
@ -364,7 +502,7 @@ FFState* LanguageModelImplementation::EvaluateChart(const ChartHypothesis& hypo,
|
||||
return ret;
|
||||
}
|
||||
|
||||
void LanguageModelImplementation::updateChartScore( float *prefixScore, float *finalizedScore, float score, size_t wordPos ) const {
|
||||
void LanguageModelImplementation::updateChartScore(float *prefixScore, float *finalizedScore, float score, size_t wordPos) const {
|
||||
if (wordPos < GetNGramOrder()) {
|
||||
*prefixScore += score;
|
||||
}
|
||||
|
@ -99,9 +99,13 @@ public:
|
||||
virtual const FFState *GetBeginSentenceState() const = 0;
|
||||
virtual FFState *NewState(const FFState *from = NULL) const = 0;
|
||||
|
||||
void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const;
|
||||
|
||||
FFState *Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out, const LanguageModel *feature) const;
|
||||
|
||||
virtual FFState* EvaluateChart(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection* accumulator, const LanguageModel *feature) const;
|
||||
|
||||
void updateChartScore( float *prefixScore, float *finalScore, float score, size_t wordPos ) const;
|
||||
void updateChartScore(float *prefixScore, float *finalScore, float score, size_t wordPos) const;
|
||||
|
||||
//! max n-gram order of LM
|
||||
size_t GetNGramOrder() const {
|
||||
|
Loading…
Reference in New Issue
Block a user