mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-25 04:43:03 +03:00
faster pruning in chart decoding
This commit is contained in:
parent
c8682e9420
commit
1f435340f0
@ -22,6 +22,7 @@
|
||||
#include "HypoList.h"
|
||||
#include "Word.h"
|
||||
#include "WordsRange.h"
|
||||
#include "ChartParserCallback.h"
|
||||
|
||||
namespace search
|
||||
{
|
||||
@ -52,7 +53,8 @@ public:
|
||||
Stack stack=Stack())
|
||||
: m_coverage(coverage)
|
||||
, m_label(label)
|
||||
, m_stack(stack) {
|
||||
, m_stack(stack)
|
||||
, m_bestScore(0) {
|
||||
}
|
||||
|
||||
const WordsRange &GetCoverage() const {
|
||||
@ -68,6 +70,14 @@ public:
|
||||
return m_stack;
|
||||
}
|
||||
|
||||
//caching of best score on stack
|
||||
float GetBestScore(const ChartParserCallback *outColl) const {
|
||||
if (m_bestScore == 0) {
|
||||
m_bestScore = outColl->GetBestScore(this);
|
||||
}
|
||||
return m_bestScore;
|
||||
}
|
||||
|
||||
bool operator<(const ChartCellLabel &other) const {
|
||||
// m_coverage and m_label uniquely identify a ChartCellLabel, so don't
|
||||
// need to compare m_stack.
|
||||
@ -81,6 +91,7 @@ private:
|
||||
const WordsRange &m_coverage;
|
||||
const Word &m_label;
|
||||
Stack m_stack;
|
||||
mutable float m_bestScore;
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ class WordsRange;
|
||||
class TargetPhrase;
|
||||
class InputPath;
|
||||
class InputType;
|
||||
class ChartCellLabel;
|
||||
|
||||
class ChartParserCallback
|
||||
{
|
||||
@ -26,7 +27,7 @@ public:
|
||||
|
||||
virtual void Evaluate(const InputType &input, const InputPath &inputPath) = 0;
|
||||
|
||||
virtual float CalcEstimateOfBestScore(const TargetPhraseCollection &, const StackVec &) const = 0;
|
||||
virtual float GetBestScore(const ChartCellLabel *chartCell) const = 0;
|
||||
|
||||
};
|
||||
|
||||
|
@ -75,7 +75,11 @@ void ChartTranslationOptionList::Add(const TargetPhraseCollection &tpc,
|
||||
}
|
||||
}
|
||||
|
||||
float score = ChartTranslationOptions::CalcEstimateOfBestScore(tpc, stackVec);
|
||||
const TargetPhrase &targetPhrase = **(tpc.begin());
|
||||
float score = targetPhrase.GetFutureScore();
|
||||
for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end(); ++p) {
|
||||
score += (*p)->GetBestScore(this);
|
||||
}
|
||||
|
||||
// If the rule limit has already been reached then don't add the option
|
||||
// unless it is better than at least one existing option.
|
||||
@ -155,6 +159,15 @@ void ChartTranslationOptionList::ApplyThreshold()
|
||||
m_size = std::distance(m_collection.begin(), bound);
|
||||
}
|
||||
|
||||
float ChartTranslationOptionList::GetBestScore(const ChartCellLabel *chartCell) const
|
||||
{
|
||||
const HypoList *stack = chartCell->GetStack().cube;
|
||||
assert(stack);
|
||||
assert(!stack->empty());
|
||||
const ChartHypothesis &bestHypo = **(stack->begin());
|
||||
return bestHypo.GetTotalScore();
|
||||
}
|
||||
|
||||
void ChartTranslationOptionList::Evaluate(const InputType &input, const InputPath &inputPath)
|
||||
{
|
||||
// NEVER iterate over ALL of the collection. Just over the first m_size
|
||||
|
@ -32,6 +32,7 @@ class TargetPhraseCollection;
|
||||
class WordsRange;
|
||||
class InputType;
|
||||
class InputPath;
|
||||
class ChartCellLabel;
|
||||
|
||||
//! a vector of translations options for a specific range, in a specific sentence
|
||||
class ChartTranslationOptionList : public ChartParserCallback
|
||||
@ -60,9 +61,7 @@ public:
|
||||
return m_size == 0;
|
||||
}
|
||||
|
||||
float CalcEstimateOfBestScore(const TargetPhraseCollection & tpc, const StackVec & stackVec) const {
|
||||
return ChartTranslationOptions::CalcEstimateOfBestScore(tpc, stackVec);
|
||||
}
|
||||
float GetBestScore(const ChartCellLabel *chartCell) const;
|
||||
|
||||
void Clear();
|
||||
void ApplyThreshold();
|
||||
|
@ -51,23 +51,6 @@ ChartTranslationOptions::~ChartTranslationOptions()
|
||||
|
||||
}
|
||||
|
||||
float ChartTranslationOptions::CalcEstimateOfBestScore(
|
||||
const TargetPhraseCollection &tpc,
|
||||
const StackVec &stackVec)
|
||||
{
|
||||
const TargetPhrase &targetPhrase = **(tpc.begin());
|
||||
float estimateOfBestScore = targetPhrase.GetFutureScore();
|
||||
for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end();
|
||||
++p) {
|
||||
const HypoList *stack = (*p)->GetStack().cube;
|
||||
assert(stack);
|
||||
assert(!stack->empty());
|
||||
const ChartHypothesis &bestHypo = **(stack->begin());
|
||||
estimateOfBestScore += bestHypo.GetTotalScore();
|
||||
}
|
||||
return estimateOfBestScore;
|
||||
}
|
||||
|
||||
void ChartTranslationOptions::Evaluate(const InputType &input, const InputPath &inputPath)
|
||||
{
|
||||
SetInputPath(&inputPath);
|
||||
|
@ -83,7 +83,7 @@ public:
|
||||
|
||||
void AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &waste_memory, const WordsRange &range);
|
||||
|
||||
float CalcEstimateOfBestScore(const TargetPhraseCollection & tpc, const StackVec & stackVec) const;
|
||||
float GetBestScore(const ChartCellLabel *chartCell) const;
|
||||
|
||||
bool Empty() const {
|
||||
return edges_.Empty();
|
||||
@ -124,8 +124,7 @@ template <class Model> void Fill<Model>::Add(const TargetPhraseCollection &targe
|
||||
float below_score = 0.0;
|
||||
for (StackVec::const_iterator i(nts.begin()); i != nts.end(); ++i) {
|
||||
vertices.push_back((*i)->GetStack().incr->RootAlternate());
|
||||
if (vertices.back().Empty()) return;
|
||||
below_score += vertices.back().Bound();
|
||||
below_score += (*i)->GetBestScore(this);
|
||||
}
|
||||
|
||||
std::vector<lm::WordIndex> words;
|
||||
@ -178,15 +177,12 @@ template <class Model> void Fill<Model>::AddPhraseOOV(TargetPhrase &phrase, std:
|
||||
edges_.AddEdge(edge);
|
||||
}
|
||||
|
||||
// NOTE(review): this span shows two function headers followed by one brace
// pair. It looks like the diff view is displaying the old
// Fill<Model>::CalcEstimateOfBestScore together with the new
// Fill<Model>::GetBestScore (hunk "-178,15 +177,12") - verify against the
// actual source file before relying on it.
// for early pruning
|
||||
template <class Model> float Fill<Model>::CalcEstimateOfBestScore(const TargetPhraseCollection &targets, const StackVec &nts) const
|
||||
// for pruning
|
||||
template <class Model> float Fill<Model>::GetBestScore(const ChartCellLabel *chartCell) const
|
||||
{
|
||||
// The next statements use `nts` and `targets`, which are parameters of the
// CalcEstimateOfBestScore signature only: they add up Bound() of every
// non-terminal's root vertex and add the first target phrase's future score.
float below_score = 0.0;
|
||||
for (StackVec::const_iterator i = nts.begin(); i != nts.end(); ++i) {
|
||||
below_score += (*i)->GetStack().incr->RootAlternate().Bound();
|
||||
}
|
||||
const TargetPhrase &targetPhrase = **(targets.begin());
|
||||
return targetPhrase.GetFutureScore() + below_score;
|
||||
// The remaining statements use `chartCell`, the parameter of the GetBestScore
// signature: they take the root vertex of that cell's incremental stack and
// return its Bound().
// NOTE(review): UTIL_THROW_IF2 presumably throws when the vertex is empty -
// confirm in the util exception header.
search::PartialVertex vertex = chartCell->GetStack().incr->RootAlternate();
|
||||
UTIL_THROW_IF2(vertex.Empty(), "hypothesis with empty stack");
|
||||
return vertex.Bound();
|
||||
}
|
||||
|
||||
// TODO: factors (but chart doesn't seem to support factors anyway).
|
||||
|
@ -41,7 +41,12 @@ void CompletedRuleCollection::Add(const TargetPhraseCollection &tpc,
|
||||
return;
|
||||
}
|
||||
|
||||
const float score = outColl.CalcEstimateOfBestScore(tpc, stackVec);
|
||||
const TargetPhrase &targetPhrase = **(tpc.begin());
|
||||
float score = targetPhrase.GetFutureScore();
|
||||
for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end(); ++p) {
|
||||
float stackScore = (*p)->GetBestScore(&outColl);
|
||||
score += stackScore;
|
||||
}
|
||||
|
||||
// If the rule limit has already been reached then don't add the option
|
||||
// unless it is better than at least one existing option.
|
||||
|
Loading…
Reference in New Issue
Block a user