diff --git a/moses/ChartCellLabel.h b/moses/ChartCellLabel.h index e00461249..144a64add 100644 --- a/moses/ChartCellLabel.h +++ b/moses/ChartCellLabel.h @@ -22,6 +22,7 @@ #include "HypoList.h" #include "Word.h" #include "WordsRange.h" +#include "ChartParserCallback.h" namespace search { @@ -52,7 +53,8 @@ public: Stack stack=Stack()) : m_coverage(coverage) , m_label(label) - , m_stack(stack) { + , m_stack(stack) + , m_bestScore(0) { } const WordsRange &GetCoverage() const { @@ -68,6 +70,14 @@ public: return m_stack; } + //caching of best score on stack + float GetBestScore(const ChartParserCallback *outColl) const { + if (m_bestScore == 0) { + m_bestScore = outColl->GetBestScore(this); + } + return m_bestScore; + } + bool operator<(const ChartCellLabel &other) const { // m_coverage and m_label uniquely identify a ChartCellLabel, so don't // need to compare m_stack. @@ -81,6 +91,7 @@ private: const WordsRange &m_coverage; const Word &m_label; Stack m_stack; + mutable float m_bestScore; }; } diff --git a/moses/ChartParserCallback.h b/moses/ChartParserCallback.h index bee69f3d7..ce4af3ab4 100644 --- a/moses/ChartParserCallback.h +++ b/moses/ChartParserCallback.h @@ -12,6 +12,7 @@ class WordsRange; class TargetPhrase; class InputPath; class InputType; +class ChartCellLabel; class ChartParserCallback { @@ -26,7 +27,7 @@ public: virtual void Evaluate(const InputType &input, const InputPath &inputPath) = 0; - virtual float CalcEstimateOfBestScore(const TargetPhraseCollection &, const StackVec &) const = 0; + virtual float GetBestScore(const ChartCellLabel *chartCell) const = 0; }; diff --git a/moses/ChartTranslationOptionList.cpp b/moses/ChartTranslationOptionList.cpp index 963e9089c..e83fbac79 100644 --- a/moses/ChartTranslationOptionList.cpp +++ b/moses/ChartTranslationOptionList.cpp @@ -75,7 +75,11 @@ void ChartTranslationOptionList::Add(const TargetPhraseCollection &tpc, } } - float score = ChartTranslationOptions::CalcEstimateOfBestScore(tpc, stackVec); + const TargetPhrase &targetPhrase = **(tpc.begin()); + float score = targetPhrase.GetFutureScore(); + for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end(); ++p) { + score += (*p)->GetBestScore(this); + } // If the rule limit has already been reached then don't add the option // unless it is better than at least one existing option. @@ -155,6 +159,15 @@ void ChartTranslationOptionList::ApplyThreshold() m_size = std::distance(m_collection.begin(), bound); } +float ChartTranslationOptionList::GetBestScore(const ChartCellLabel *chartCell) const +{ + const HypoList *stack = chartCell->GetStack().cube; + assert(stack); + assert(!stack->empty()); + const ChartHypothesis &bestHypo = **(stack->begin()); + return bestHypo.GetTotalScore(); +} + void ChartTranslationOptionList::Evaluate(const InputType &input, const InputPath &inputPath) { // NEVER iterate over ALL of the collection. Just over the first m_size diff --git a/moses/ChartTranslationOptionList.h b/moses/ChartTranslationOptionList.h index 04e3c7e2f..d62a89ab3 100644 --- a/moses/ChartTranslationOptionList.h +++ b/moses/ChartTranslationOptionList.h @@ -32,6 +32,7 @@ class TargetPhraseCollection; class WordsRange; class InputType; class InputPath; +class ChartCellLabel; //! a vector of translations options for a specific range, in a specific sentence class ChartTranslationOptionList : public ChartParserCallback @@ -60,9 +61,7 @@ public: return m_size == 0; } - float CalcEstimateOfBestScore(const TargetPhraseCollection & tpc, const StackVec & stackVec) const { - return ChartTranslationOptions::CalcEstimateOfBestScore(tpc, stackVec); - } + float GetBestScore(const ChartCellLabel *chartCell) const; void Clear(); void ApplyThreshold(); diff --git a/moses/ChartTranslationOptions.cpp b/moses/ChartTranslationOptions.cpp index 59cbe1463..641d15da5 100644 --- a/moses/ChartTranslationOptions.cpp +++ b/moses/ChartTranslationOptions.cpp @@ -51,23 +51,6 @@ ChartTranslationOptions::~ChartTranslationOptions() } -float ChartTranslationOptions::CalcEstimateOfBestScore( - const TargetPhraseCollection &tpc, - const StackVec &stackVec) -{ - const TargetPhrase &targetPhrase = **(tpc.begin()); - float estimateOfBestScore = targetPhrase.GetFutureScore(); - for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end(); - ++p) { - const HypoList *stack = (*p)->GetStack().cube; - assert(stack); - assert(!stack->empty()); - const ChartHypothesis &bestHypo = **(stack->begin()); - estimateOfBestScore += bestHypo.GetTotalScore(); - } - return estimateOfBestScore; -} - void ChartTranslationOptions::Evaluate(const InputType &input, const InputPath &inputPath) { SetInputPath(&inputPath); diff --git a/moses/Incremental.cpp b/moses/Incremental.cpp index 2ee98aade..e55cf5e11 100644 --- a/moses/Incremental.cpp +++ b/moses/Incremental.cpp @@ -83,7 +83,7 @@ public: void AddPhraseOOV(TargetPhrase &phrase, std::list &waste_memory, const WordsRange &range); - float CalcEstimateOfBestScore(const TargetPhraseCollection & tpc, const StackVec & stackVec) const; + float GetBestScore(const ChartCellLabel *chartCell) const; bool Empty() const { return edges_.Empty(); @@ -124,8 +124,7 @@ template void Fill::Add(const TargetPhraseCollection &targe float below_score = 0.0; for (StackVec::const_iterator i(nts.begin()); i != nts.end(); ++i) { vertices.push_back((*i)->GetStack().incr->RootAlternate()); - if (vertices.back().Empty()) return; - below_score += vertices.back().Bound(); + below_score += (*i)->GetBestScore(this); } std::vector words; @@ -178,15 +177,12 @@ template void Fill::AddPhraseOOV(TargetPhrase &phrase, std: edges_.AddEdge(edge); } -// for early pruning -template float Fill::CalcEstimateOfBestScore(const TargetPhraseCollection &targets, const StackVec &nts) const +// for pruning +template float Fill::GetBestScore(const ChartCellLabel *chartCell) const { - float below_score = 0.0; - for (StackVec::const_iterator i = nts.begin(); i != nts.end(); ++i) { - below_score += (*i)->GetStack().incr->RootAlternate().Bound(); - } - const TargetPhrase &targetPhrase = **(targets.begin()); - return targetPhrase.GetFutureScore() + below_score; + search::PartialVertex vertex = chartCell->GetStack().incr->RootAlternate(); + UTIL_THROW_IF2(vertex.Empty(), "hypothesis with empty stack"); + return vertex.Bound(); } // TODO: factors (but chart doesn't seem to support factors anyway). diff --git a/moses/TranslationModel/CYKPlusParser/CompletedRuleCollection.cpp b/moses/TranslationModel/CYKPlusParser/CompletedRuleCollection.cpp index 57cd25994..d060ec273 100644 --- a/moses/TranslationModel/CYKPlusParser/CompletedRuleCollection.cpp +++ b/moses/TranslationModel/CYKPlusParser/CompletedRuleCollection.cpp @@ -41,7 +41,12 @@ void CompletedRuleCollection::Add(const TargetPhraseCollection &tpc, return; } - const float score = outColl.CalcEstimateOfBestScore(tpc, stackVec); + const TargetPhrase &targetPhrase = **(tpc.begin()); + float score = targetPhrase.GetFutureScore(); + for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end(); ++p) { + float stackScore = (*p)->GetBestScore(&outColl); + score += stackScore; + } // If the rule limit has already been reached then don't add the option // unless it is better than at least one existing option.