From 866f746f98df5b6390189c889dde686249bf9cb5 Mon Sep 17 00:00:00 2001 From: Phil Williams Date: Mon, 12 Jan 2015 08:47:33 +0000 Subject: [PATCH] s2t decoder: fix scoring Restore scoring function that was commented out in commit 465b4756... --- moses/LM/Ken.cpp | 88 ++++++++++++++++++++++++------------------------ moses/LM/Ken.h | 2 +- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/moses/LM/Ken.cpp b/moses/LM/Ken.cpp index 5a2b75052..76bf12593 100644 --- a/moses/LM/Ken.cpp +++ b/moses/LM/Ken.cpp @@ -360,50 +360,50 @@ template FFState *LanguageModelKen::EvaluateWhenApplied(con return newState; } -//template FFState *LanguageModelKen::EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const -//{ -// LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM(); -// lm::ngram::RuleScore ruleScore(*m_ngram, newState->GetChartState()); -// const TargetPhrase &target = *hyperedge.translation; -// const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = -// target.GetAlignNonTerm().GetNonTermIndexMap2(); -// -// const size_t size = target.GetSize(); -// size_t phrasePos = 0; -// // Special cases for first word. -// if (size) { -// const Word &word = target.GetWord(0); -// if (word.GetFactor(m_factorType) == m_beginSentenceFactor) { -// // Begin of sentence -// ruleScore.BeginSentence(); -// phrasePos++; -// } else if (word.IsNonTerminal()) { -// // Non-terminal is first so we can copy instead of rescoring. -// const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]]; -// const lm::ngram::ChartState &prevState = static_cast(pred->state[featureID])->GetChartState(); -// float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]); -// ruleScore.BeginNonTerminal(prevState, prob); -// phrasePos++; -// } -// } -// -// for (; phrasePos < size; phrasePos++) { -// const Word &word = target.GetWord(phrasePos); -// if (word.IsNonTerminal()) { -// const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]]; -// const lm::ngram::ChartState &prevState = static_cast(pred->state[featureID])->GetChartState(); -// float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]); -// ruleScore.NonTerminal(prevState, prob); -// } else { -// ruleScore.Terminal(TranslateID(word)); -// } -// } -// -// float score = ruleScore.Finish(); -// score = TransformLMScore(score); -// accumulator->Assign(this, score); -// return newState; -//} +template FFState *LanguageModelKen::EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const +{ + LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM(); + lm::ngram::RuleScore ruleScore(*m_ngram, newState->GetChartState()); + const TargetPhrase &target = *hyperedge.translation; + const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = + target.GetAlignNonTerm().GetNonTermIndexMap2(); + + const size_t size = target.GetSize(); + size_t phrasePos = 0; + // Special cases for first word. + if (size) { + const Word &word = target.GetWord(0); + if (word.GetFactor(m_factorType) == m_beginSentenceFactor) { + // Begin of sentence + ruleScore.BeginSentence(); + phrasePos++; + } else if (word.IsNonTerminal()) { + // Non-terminal is first so we can copy instead of rescoring. + const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]]; + const lm::ngram::ChartState &prevState = static_cast(pred->state[featureID])->GetChartState(); + float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]); + ruleScore.BeginNonTerminal(prevState, prob); + phrasePos++; + } + } + + for (; phrasePos < size; phrasePos++) { + const Word &word = target.GetWord(phrasePos); + if (word.IsNonTerminal()) { + const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]]; + const lm::ngram::ChartState &prevState = static_cast(pred->state[featureID])->GetChartState(); + float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]); + ruleScore.NonTerminal(prevState, prob); + } else { + ruleScore.Terminal(TranslateID(word)); + } + } + + float score = ruleScore.Finish(); + score = TransformLMScore(score); + accumulator->Assign(this, score); + return newState; +} template void LanguageModelKen::IncrementalCallback(Incremental::Manager &manager) const { diff --git a/moses/LM/Ken.h b/moses/LM/Ken.h index b23027d05..a2fdb6013 100644 --- a/moses/LM/Ken.h +++ b/moses/LM/Ken.h @@ -59,7 +59,7 @@ public: virtual FFState *EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection *accumulator) const; -// virtual FFState *EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const; + virtual FFState *EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const; virtual void IncrementalCallback(Incremental::Manager &manager) const; virtual void ReportHistoryOrder(std::ostream &out,const Phrase &phrase) const;