Rule score with left rest (default same)

Kenneth Heafield 2012-10-13 16:49:40 +01:00
parent 018696367b
commit 1b759c169a


@@ -176,38 +176,47 @@ template <class Model> void LanguageModelKen<Model>::CalcScore(const Phrase &phr
  if (!phrase.GetSize()) return;
  typename Model::State state_backing[2];
  typename Model::State *state0 = &state_backing[0], *state1 = &state_backing[1];
  lm::ngram::ChartState discarded_sadly;
  lm::ngram::RuleScore<Model> scorer(*m_ngram, discarded_sadly);
  size_t position;
  if (m_beginSentenceFactor == phrase.GetWord(0).GetFactor(m_factorType)) {
    *state0 = m_ngram->BeginSentenceState();
    scorer.BeginSentence();
    position = 1;
  } else {
    *state0 = m_ngram->NullContextState();
    position = 0;
  }
  size_t ngramBoundary = m_ngram->Order() - 1;
  for (; position < phrase.GetSize(); ++position) {
  size_t end_loop = std::min(ngramBoundary, phrase.GetSize());
  for (; position < end_loop; ++position) {
    const Word &word = phrase.GetWord(position);
    if (word.IsNonTerminal()) {
      // If there's a non-terminal at 1 and we have a 5-gram LM, then positions 2 3 4 and 5 will be incomplete while position 6 is complete.
      ngramBoundary = m_ngram->Order() + position;
      *state0 = m_ngram->NullContextState();
      fullScore += scorer.Finish();
      scorer.Reset();
    } else {
      lm::WordIndex index = TranslateID(word);
      if (index == m_ngram->GetVocabulary().BeginSentence()) {
        std::cerr << "Either your data contains <s> in a position other than the first word or your language model is missing <s>. Did you build your ARPA using IRSTLM and forget to run add-start-end.sh?" << std::endl;
        abort();
      }
      float score = TransformLMScore(m_ngram->Score(*state0, index, *state1));
      std::swap(state0, state1);
      if (position >= ngramBoundary) ngramScore += score;
      fullScore += score;
      scorer.Terminal(index);
      if (!index) ++oovCount;
    }
  }
  float before_boundary = fullScore + scorer.Finish();
  for (; position < phrase.GetSize(); ++position) {
    const Word &word = phrase.GetWord(position);
    if (word.IsNonTerminal()) {
      fullScore += scorer.Finish();
      scorer.Reset();
    } else {
      lm::WordIndex index = TranslateID(word);
      scorer.Terminal(index);
      if (!index) ++oovCount;
    }
  }
  fullScore += scorer.Finish();
  ngramScore = TransformLMScore(fullScore - before_boundary);
  fullScore = TransformLMScore(fullScore);
}
template <class Model> FFState *LanguageModelKen<Model>::Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const {
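The patched CalcScore drives KenLM's lm::ngram::RuleScore (from lm/left.hh), which accumulates the score of one grammar rule and records its left/right language-model context in a ChartState. At each non-terminal the scorer is finished and reset, since the words the non-terminal will cover are not known yet and the following terminals restart from empty context; ngramScore then falls out as the difference between the total and the score accumulated before the n-gram boundary. "Left rest" presumably refers to the rest cost charged for words whose full left context is not yet known, which by default equals the ordinary probability ("default same"). A minimal standalone sketch of that API, assuming a KenLM build and a model file named lm.arpa (the file name and the scored words are illustrative placeholders, not part of this commit):

#include <iostream>

#include "lm/left.hh"
#include "lm/model.hh"

int main() {
  // Load a model; lm::ngram::Model is KenLM's default probing model.
  lm::ngram::Model model("lm.arpa");

  // RuleScore accumulates a rule's score and writes the rule's left/right
  // LM state into a ChartState, as CalcScore does per span of terminals.
  lm::ngram::ChartState state;
  lm::ngram::RuleScore<lm::ngram::Model> scorer(model, state);

  // Score a short terminal sequence; words without full left context are
  // charged a rest cost (by default, just their regular probability).
  scorer.Terminal(model.GetVocabulary().Index("the"));
  scorer.Terminal(model.GetVocabulary().Index("cat"));

  // Finish() finalizes the ChartState and returns the accumulated
  // log10 probability.
  std::cout << scorer.Finish() << std::endl;
  return 0;
}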