faster pruning in chart decoding

This commit is contained in:
Rico Sennrich 2014-03-26 11:23:23 +00:00
parent c8682e9420
commit 1f435340f0
7 changed files with 43 additions and 35 deletions

View File

@ -22,6 +22,7 @@
#include "HypoList.h"
#include "Word.h"
#include "WordsRange.h"
#include "ChartParserCallback.h"
namespace search
{
@ -52,7 +53,8 @@ public:
Stack stack=Stack())
: m_coverage(coverage)
, m_label(label)
, m_stack(stack) {
, m_stack(stack)
, m_bestScore(0) {
}
const WordsRange &GetCoverage() const {
@ -68,6 +70,14 @@ public:
return m_stack;
}
// Cached best score for this label's stack. While the stored value is 0 the
// score is requested from outColl and remembered; any non-zero stored value
// is returned directly. A score of exactly 0 is therefore looked up again on
// each call.
float GetBestScore(const ChartParserCallback *outColl) const {
  if (m_bestScore != 0) {
    return m_bestScore;
  }
  m_bestScore = outColl->GetBestScore(this);
  return m_bestScore;
}
bool operator<(const ChartCellLabel &other) const {
// m_coverage and m_label uniquely identify a ChartCellLabel, so don't
// need to compare m_stack.
@ -81,6 +91,7 @@ private:
const WordsRange &m_coverage;
const Word &m_label;
Stack m_stack;
mutable float m_bestScore;
};
}

View File

@ -12,6 +12,7 @@ class WordsRange;
class TargetPhrase;
class InputPath;
class InputType;
class ChartCellLabel;
class ChartParserCallback
{
@ -26,7 +27,7 @@ public:
virtual void Evaluate(const InputType &input, const InputPath &inputPath) = 0;
virtual float CalcEstimateOfBestScore(const TargetPhraseCollection &, const StackVec &) const = 0;
virtual float GetBestScore(const ChartCellLabel *chartCell) const = 0;
};

View File

@ -75,7 +75,11 @@ void ChartTranslationOptionList::Add(const TargetPhraseCollection &tpc,
}
}
float score = ChartTranslationOptions::CalcEstimateOfBestScore(tpc, stackVec);
const TargetPhrase &targetPhrase = **(tpc.begin());
float score = targetPhrase.GetFutureScore();
for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end(); ++p) {
score += (*p)->GetBestScore(this);
}
// If the rule limit has already been reached then don't add the option
// unless it is better than at least one existing option.
@ -155,6 +159,15 @@ void ChartTranslationOptionList::ApplyThreshold()
m_size = std::distance(m_collection.begin(), bound);
}
float ChartTranslationOptionList::GetBestScore(const ChartCellLabel *chartCell) const
{
const HypoList *stack = chartCell->GetStack().cube;
assert(stack);
assert(!stack->empty());
const ChartHypothesis &bestHypo = **(stack->begin());
return bestHypo.GetTotalScore();
}
void ChartTranslationOptionList::Evaluate(const InputType &input, const InputPath &inputPath)
{
// NEVER iterate over ALL of the collection. Just over the first m_size

View File

@ -32,6 +32,7 @@ class TargetPhraseCollection;
class WordsRange;
class InputType;
class InputPath;
class ChartCellLabel;
//! a vector of translations options for a specific range, in a specific sentence
class ChartTranslationOptionList : public ChartParserCallback
@ -60,9 +61,7 @@ public:
return m_size == 0;
}
// Delegates to ChartTranslationOptions::CalcEstimateOfBestScore with the same
// arguments and returns its result.
float CalcEstimateOfBestScore(const TargetPhraseCollection & tpc, const StackVec & stackVec) const {
  const float estimate = ChartTranslationOptions::CalcEstimateOfBestScore(tpc, stackVec);
  return estimate;
}
float GetBestScore(const ChartCellLabel *chartCell) const;
void Clear();
void ApplyThreshold();

View File

@ -51,23 +51,6 @@ ChartTranslationOptions::~ChartTranslationOptions()
}
float ChartTranslationOptions::CalcEstimateOfBestScore(
const TargetPhraseCollection &tpc,
const StackVec &stackVec)
{
const TargetPhrase &targetPhrase = **(tpc.begin());
float estimateOfBestScore = targetPhrase.GetFutureScore();
for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end();
++p) {
const HypoList *stack = (*p)->GetStack().cube;
assert(stack);
assert(!stack->empty());
const ChartHypothesis &bestHypo = **(stack->begin());
estimateOfBestScore += bestHypo.GetTotalScore();
}
return estimateOfBestScore;
}
void ChartTranslationOptions::Evaluate(const InputType &input, const InputPath &inputPath)
{
SetInputPath(&inputPath);

View File

@ -83,7 +83,7 @@ public:
void AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &waste_memory, const WordsRange &range);
float CalcEstimateOfBestScore(const TargetPhraseCollection & tpc, const StackVec & stackVec) const;
float GetBestScore(const ChartCellLabel *chartCell) const;
bool Empty() const {
return edges_.Empty();
@ -124,8 +124,7 @@ template <class Model> void Fill<Model>::Add(const TargetPhraseCollection &targe
float below_score = 0.0;
for (StackVec::const_iterator i(nts.begin()); i != nts.end(); ++i) {
vertices.push_back((*i)->GetStack().incr->RootAlternate());
if (vertices.back().Empty()) return;
below_score += vertices.back().Bound();
below_score += (*i)->GetBestScore(this);
}
std::vector<lm::WordIndex> words;
@ -178,15 +177,12 @@ template <class Model> void Fill<Model>::AddPhraseOOV(TargetPhrase &phrase, std:
edges_.AddEdge(edge);
}
// for early pruning
template <class Model> float Fill<Model>::CalcEstimateOfBestScore(const TargetPhraseCollection &targets, const StackVec &nts) const
// for pruning
template <class Model> float Fill<Model>::GetBestScore(const ChartCellLabel *chartCell) const
{
float below_score = 0.0;
for (StackVec::const_iterator i = nts.begin(); i != nts.end(); ++i) {
below_score += (*i)->GetStack().incr->RootAlternate().Bound();
}
const TargetPhrase &targetPhrase = **(targets.begin());
return targetPhrase.GetFutureScore() + below_score;
search::PartialVertex vertex = chartCell->GetStack().incr->RootAlternate();
UTIL_THROW_IF2(vertex.Empty(), "hypothesis with empty stack");
return vertex.Bound();
}
// TODO: factors (but chart doesn't seem to support factors anyway).

View File

@ -41,7 +41,12 @@ void CompletedRuleCollection::Add(const TargetPhraseCollection &tpc,
return;
}
const float score = outColl.CalcEstimateOfBestScore(tpc, stackVec);
const TargetPhrase &targetPhrase = **(tpc.begin());
float score = targetPhrase.GetFutureScore();
for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end(); ++p) {
float stackScore = (*p)->GetBestScore(&outColl);
score += stackScore;
}
// If the rule limit has already been reached then don't add the option
// unless it is better than at least one existing option.