Changes to chart decoder cube pruning: create one cube per dotted rule

instead of one per translation and do 'non-lazy' scoring, i.e. fully
score the corner and neighbor hypotheses inside the rule cube instead
of waiting until an item is popped.  The old behaviour -- faster but
with more search errors -- is available via the
cube-pruning-lazy-scoring option.

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4039 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
pjwilliams 2011-06-27 15:13:15 +00:00
parent c7cc79a20e
commit 2451371ca2
27 changed files with 682 additions and 437 deletions

View File

@ -144,7 +144,7 @@ Moses::TargetPhraseCollection *TargetPhraseCollection::ConvertToMoses(const std:
ret->Add(mosesPhrase);
}
ret->Prune(true, phraseDict.GetTableLimit());
ret->Sort(true, phraseDict.GetTableLimit());
return ret;

View File

@ -43,6 +43,10 @@ POSSIBILITY OF SUCH DAMAGE.
#include "PhraseDictionary.h"
#include "ChartTrellisPathList.h"
#include "ChartTrellisPath.h"
#include "ChartTranslationOption.h"
#include "ChartHypothesis.h"
#include "CoveredChartSpan.h"
using namespace std;
using namespace Moses;
@ -223,7 +227,11 @@ void OutputTranslationOptions(std::ostream &out, const ChartHypothesis *hypo, lo
{
// recursive
if (hypo != NULL) {
out << "Trans Opt " << translationId << " " << hypo->GetCurrSourceRange() << ": " << hypo->GetTranslationOption()
out << "Trans Opt " << translationId
<< " " << hypo->GetCurrSourceRange()
<< ": " << hypo->GetTranslationOption().GetLastCoveredChartSpan()
<< ": " << hypo->GetCurrTargetPhrase().GetTargetLHS()
<< "->" << hypo->GetCurrTargetPhrase()
<< " " << hypo->GetTotalScore() << hypo->GetScoreBreakdown()
<< endl;
}

View File

@ -32,7 +32,6 @@
#include "ChartTranslationOptionList.h"
using namespace std;
using namespace Moses;
namespace Moses
{
@ -83,14 +82,14 @@ void ChartCell::ProcessSentence(const ChartTranslationOptionList &transOptList
const StaticData &staticData = StaticData::Instance();
// priority queue for applicable rules with selected hypotheses
RuleCubeQueue queue;
RuleCubeQueue queue(m_manager);
// add all trans opt into queue. using only 1st child node.
ChartTranslationOptionList::const_iterator iterList;
for (iterList = transOptList.begin(); iterList != transOptList.end(); ++iterList)
{
const ChartTranslationOption &transOpt = **iterList;
RuleCube *ruleCube = new RuleCube(transOpt, allChartCells);
RuleCube *ruleCube = new RuleCube(transOpt, allChartCells, m_manager);
queue.Add(ruleCube);
}
@ -98,16 +97,8 @@ void ChartCell::ProcessSentence(const ChartTranslationOptionList &transOptList
const size_t popLimit = staticData.GetCubePruningPopLimit();
for (size_t numPops = 0; numPops < popLimit && !queue.IsEmpty(); ++numPops)
{
RuleCube *ruleCube = queue.Pop();
// create hypothesis from RuleCube
ChartHypothesis *hypo = new ChartHypothesis(*ruleCube, m_manager);
assert(hypo);
hypo->CalcScore();
ChartHypothesis *hypo = queue.Pop();
AddHypothesis(hypo);
// add neighbors to the queue
ruleCube->CreateNeighbors(queue);
}
}
@ -223,5 +214,3 @@ std::ostream& operator<<(std::ostream &out, const ChartCell &cell)
}
} // namespace

View File

@ -22,7 +22,7 @@
#include <algorithm>
#include <vector>
#include "ChartHypothesis.h"
#include "RuleCube.h"
#include "RuleCubeItem.h"
#include "ChartCell.h"
#include "ChartManager.h"
#include "TargetPhrase.h"
@ -33,9 +33,6 @@
#include "ChartTranslationOption.h"
#include "FFState.h"
using namespace std;
using namespace Moses;
namespace Moses
{
unsigned int ChartHypothesis::s_HypothesesCreated = 0;
@ -45,32 +42,33 @@ ObjectPool<ChartHypothesis> ChartHypothesis::s_objectPool("ChartHypothesis", 300
#endif
/** Create a hypothesis from a rule */
ChartHypothesis::ChartHypothesis(const RuleCube &ruleCube, ChartManager &manager)
:m_transOpt(ruleCube.GetTranslationOption())
,m_id(++s_HypothesesCreated)
,m_currSourceWordsRange(ruleCube.GetTranslationOption().GetSourceWordsRange())
,m_ffStates(manager.GetTranslationSystem()->GetStatefulFeatureFunctions().size())
ChartHypothesis::ChartHypothesis(const ChartTranslationOption &transOpt,
const RuleCubeItem &item,
ChartManager &manager)
:m_id(++s_HypothesesCreated)
,m_targetPhrase(*(item.GetTranslationDimension().GetTargetPhrase()))
,m_transOpt(transOpt)
,m_contextPrefix(Output, manager.GetTranslationSystem()->GetLanguageModels().GetMaxNGramOrder())
,m_contextSuffix(Output, manager.GetTranslationSystem()->GetLanguageModels().GetMaxNGramOrder())
,m_currSourceWordsRange(transOpt.GetSourceWordsRange())
,m_ffStates(manager.GetTranslationSystem()->GetStatefulFeatureFunctions().size())
,m_arcList(NULL)
,m_winningHypo(NULL)
,m_manager(manager)
{
//TRACE_ERR(m_targetPhrase << endl);
// underlying hypotheses for sub-spans
m_numTargetTerminals = GetCurrTargetPhrase().GetNumTerminals();
const std::vector<RuleCubeDimension> &childEntries = ruleCube.GetCube();
const std::vector<HypothesisDimension> &childEntries = item.GetHypothesisDimensions();
// ... are stored
assert(m_prevHypos.empty());
m_prevHypos.reserve(childEntries.size());
vector<RuleCubeDimension>::const_iterator iter;
std::vector<HypothesisDimension>::const_iterator iter;
for (iter = childEntries.begin(); iter != childEntries.end(); ++iter)
{
const RuleCubeDimension &ruleCubeDimension = *iter;
const ChartHypothesis *prevHypo = ruleCubeDimension.GetHypothesis();
const HypothesisDimension &dimension = *iter;
const ChartHypothesis *prevHypo = dimension.GetHypothesis();
// keep count of words (= length of generated string)
m_numTargetTerminals += prevHypo->GetNumTargetTerminals();
@ -179,8 +177,8 @@ size_t ChartHypothesis::CalcSuffix(Phrase &ret, size_t size) const
// special handling for small hypotheses
// does the prefix match the entire hypothesis string? -> just copy prefix
if (m_contextPrefix.GetSize() == m_numTargetTerminals) {
size_t maxCount = min(m_contextPrefix.GetSize(), size)
, pos = m_contextPrefix.GetSize() - 1;
size_t maxCount = std::min(m_contextPrefix.GetSize(), size);
size_t pos= m_contextPrefix.GetSize() - 1;
for (size_t ind = 0; ind < maxCount; ++ind) {
const Word &word = m_contextPrefix.GetWord(pos);
@ -267,7 +265,7 @@ void ChartHypothesis::CalcScore()
// sfs[i]->ChartEvaluate(m_targetPhrase, &m_scoreBreakdown);
//}
const vector<const StatefulFeatureFunction*>& ffs =
const std::vector<const StatefulFeatureFunction*>& ffs =
m_manager.GetTranslationSystem()->GetStatefulFeatureFunctions();
for (unsigned i = 0; i < ffs.size(); ++i) {
m_ffStates[i] = ffs[i]->EvaluateChart(*this,i,&m_scoreBreakdown);
@ -361,7 +359,7 @@ void ChartHypothesis::SetWinningHypo(const ChartHypothesis *hypo)
TO_STRING_BODY(ChartHypothesis)
// friend
ostream& operator<<(ostream& out, const ChartHypothesis& hypo)
std::ostream& operator<<(std::ostream& out, const ChartHypothesis& hypo)
{
out << hypo.GetId();
@ -392,4 +390,3 @@ ostream& operator<<(ostream& out, const ChartHypothesis& hypo)
}
}

View File

@ -31,9 +31,10 @@
namespace Moses
{
class RuleCube;
class ChartHypothesis;
class ChartManager;
class RuleCubeItem;
typedef std::vector<ChartHypothesis*> ChartArcList;
@ -50,6 +51,7 @@ protected:
static unsigned int s_HypothesesCreated;
int m_id; /**< numeric ID of this hypothesis, used for logging */
const TargetPhrase &m_targetPhrase;
const ChartTranslationOption &m_transOpt;
Phrase m_contextPrefix, m_contextSuffix;
@ -97,7 +99,9 @@ public:
}
#endif
explicit ChartHypothesis(const RuleCube &ruleCube, ChartManager &manager);
ChartHypothesis(const ChartTranslationOption &, const RuleCubeItem &item,
ChartManager &manager);
~ChartHypothesis();
int GetId()const {
@ -107,7 +111,7 @@ public:
return m_transOpt;
}
const TargetPhrase &GetCurrTargetPhrase()const {
return m_transOpt.GetTargetPhrase();
return m_targetPhrase;
}
const WordsRange &GetCurrSourceRange()const {
return m_currSourceWordsRange;

View File

@ -268,8 +268,10 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
}
assert(targetPhraseCollection);
outColl.Add(*targetPhraseCollection, *coveredChartSpan,
GetCellCollection(), adhereTableLimit, rulesLimit);
if (!targetPhraseCollection->IsEmpty()) {
outColl.Add(*targetPhraseCollection, *coveredChartSpan,
GetCellCollection(), adhereTableLimit, rulesLimit);
}
} // if (node)

View File

@ -19,12 +19,12 @@
***********************************************************************/
#include "ChartTranslationOption.h"
#include "TargetPhrase.h"
#include "AlignmentInfo.h"
#include "CoveredChartSpan.h"
#include "ChartCellCollection.h"
using namespace std;
#include "AlignmentInfo.h"
#include "ChartCellCollection.h"
#include "CoveredChartSpan.h"
#include <vector>
namespace Moses
{
@ -54,18 +54,11 @@ void ChartTranslationOption::CalcEstimateOfBestScore(
assert(!childCell.GetSortedHypotheses(nonTerm).empty());
// create a list of hypotheses that match the non-terminal
const vector<const ChartHypothesis *> &stack =
const std::vector<const ChartHypothesis *> &stack =
childCell.GetSortedHypotheses(nonTerm);
const ChartHypothesis *hypo = stack[0];
m_estimateOfBestScore += hypo->GetTotalScore();
}
}
std::ostream& operator<<(std::ostream &out, const ChartTranslationOption &rule)
{
out << rule.m_lastCoveredChartSpan << ": " << rule.m_targetPhrase.GetTargetLHS() << "->" << rule.m_targetPhrase;
return out;
}
} // namespace

View File

@ -20,70 +20,68 @@
#pragma once
#include "TargetPhrase.h"
#include "TargetPhraseCollection.h"
#include "WordsRange.h"
#include <cassert>
#include <vector>
#include "Word.h"
#include "WordsRange.h"
#include "TargetPhrase.h"
namespace Moses
{
class CoveredChartSpan;
class ChartCellCollection;
// basically a phrase translation and the vector of words consumed to map each word
// Similar to a DottedRule, but contains a direct reference to a list
// of translations and provides an estimate of the best score.
class ChartTranslationOption
{
friend std::ostream& operator<<(std::ostream&, const ChartTranslationOption&);
protected:
const TargetPhrase &m_targetPhrase;
const CoveredChartSpan &m_lastCoveredChartSpan;
/* map each source word in the phrase table to:
1. a word in the input sentence, if the pt word is a terminal
2. a 1+ phrase in the input sentence, if the pt word is a non-terminal
*/
const WordsRange &m_wordsRange;
float m_estimateOfBestScore;
ChartTranslationOption &operator=(const ChartTranslationOption &); // not implemented
void CalcEstimateOfBestScore(const CoveredChartSpan *, const ChartCellCollection &);
public:
ChartTranslationOption(const TargetPhrase &targetPhrase, const CoveredChartSpan &lastCoveredChartSpan, const WordsRange &wordsRange, const ChartCellCollection &allChartCells)
:m_targetPhrase(targetPhrase)
,m_lastCoveredChartSpan(lastCoveredChartSpan)
,m_wordsRange(wordsRange)
,m_estimateOfBestScore(m_targetPhrase.GetFutureScore())
public:
ChartTranslationOption(const TargetPhraseCollection &targetPhraseColl,
const CoveredChartSpan &lastCoveredChartSpan,
const WordsRange &wordsRange,
const ChartCellCollection &allChartCells)
: m_lastCoveredChartSpan(lastCoveredChartSpan)
, m_targetPhraseCollection(targetPhraseColl)
, m_wordsRange(wordsRange)
, m_estimateOfBestScore(0)
{
const TargetPhrase &targetPhrase = **(m_targetPhraseCollection.begin());
m_estimateOfBestScore = targetPhrase.GetFutureScore();
CalcEstimateOfBestScore(&m_lastCoveredChartSpan, allChartCells);
}
~ChartTranslationOption()
{}
const TargetPhrase &GetTargetPhrase() const {
return m_targetPhrase;
}
~ChartTranslationOption() {}
const CoveredChartSpan &GetLastCoveredChartSpan() const {
return m_lastCoveredChartSpan;
}
const TargetPhraseCollection &GetTargetPhraseCollection() const {
return m_targetPhraseCollection;
}
const WordsRange &GetSourceWordsRange() const {
return m_wordsRange;
}
// return an estimate of the best score possible with this translation option.
// the estimate is the sum of the target phrase's estimated score plus the
// scores of the best child hypotheses. (the same as the ordering criterion
// currently used in RuleCubeQueue.)
inline float GetEstimateOfBestScore() const {
return m_estimateOfBestScore;
}
// the estimate is the sum of the top target phrase's estimated score plus the
// scores of the best child hypotheses.
inline float GetEstimateOfBestScore() const { return m_estimateOfBestScore; }
private:
// not implemented
ChartTranslationOption &operator=(const ChartTranslationOption &);
void CalcEstimateOfBestScore(const CoveredChartSpan *,
const ChartCellCollection &);
const CoveredChartSpan &m_lastCoveredChartSpan;
const TargetPhraseCollection &m_targetPhraseCollection;
const WordsRange &m_wordsRange;
float m_estimateOfBestScore;
};
}

View File

@ -30,7 +30,6 @@
#include "Util.h"
using namespace std;
using namespace Moses;
namespace Moses
{
@ -43,8 +42,8 @@ ChartTranslationOptionCollection::ChartTranslationOptionCollection(InputType con
,m_system(system)
,m_decodeGraphList(system->GetDecodeGraphs())
,m_hypoStackColl(hypoStackColl)
,m_collection(source.GetSize())
,m_ruleLookupManagers(ruleLookupManagers)
,m_collection(source.GetSize())
{
// create 2-d vector
size_t size = source.GetSize();
@ -59,7 +58,7 @@ ChartTranslationOptionCollection::ChartTranslationOptionCollection(InputType con
ChartTranslationOptionCollection::~ChartTranslationOptionCollection()
{
RemoveAllInColl(m_unksrcs);
RemoveAllInColl(m_cacheTargetPhrase);
RemoveAllInColl(m_cacheTargetPhraseCollection);
std::list<std::vector<CoveredChartSpan*>* >::iterator iterOuter;
for (iterOuter = m_coveredChartSpanCache.begin(); iterOuter != m_coveredChartSpanCache.end(); ++iterOuter) {
@ -225,8 +224,10 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW
// add to dictionary
TargetPhrase *targetPhrase = new TargetPhrase(Output);
TargetPhraseCollection *tpc = new TargetPhraseCollection;
tpc->Add(targetPhrase);
m_cacheTargetPhrase.push_back(targetPhrase);
m_cacheTargetPhraseCollection.push_back(tpc);
Word &targetWord = targetPhrase->AddWord();
targetWord.CreateUnknownWord(sourceWord);
@ -240,7 +241,7 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW
targetPhrase->SetTargetLHS(targetLHS);
// chart rule
ChartTranslationOption *chartRule = new ChartTranslationOption(*targetPhrase
ChartTranslationOption *chartRule = new ChartTranslationOption(*tpc
, *coveredChartSpanList->back()
, range
, m_hypoStackColl);
@ -251,6 +252,8 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW
vector<float> unknownScore(1, FloorScore(-numeric_limits<float>::infinity()));
TargetPhrase *targetPhrase = new TargetPhrase(Output);
TargetPhraseCollection *tpc = new TargetPhraseCollection;
tpc->Add(targetPhrase);
// loop
const UnknownLHSList &lhsList = staticData.GetUnknownLHS();
UnknownLHSList::const_iterator iterLHS;
@ -262,7 +265,7 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW
targetLHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), targetLHSStr, true);
assert(targetLHS.GetFactor(0) != NULL);
m_cacheTargetPhrase.push_back(targetPhrase);
m_cacheTargetPhraseCollection.push_back(tpc);
targetPhrase->SetSourcePhrase(m_unksrc);
targetPhrase->SetScore(unknownWordPenaltyProducer, unknownScore);
targetPhrase->SetTargetLHS(targetLHS);
@ -274,7 +277,7 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW
// chart rule
assert(coveredChartSpanList->size());
ChartTranslationOption *chartRule = new ChartTranslationOption(*targetPhrase
ChartTranslationOption *chartRule = new ChartTranslationOption(*tpc
, *coveredChartSpanList->back()
, range
, m_hypoStackColl);
@ -302,7 +305,4 @@ void ChartTranslationOptionCollection::Sort(size_t startPos, size_t endPos)
list.Sort();
}
} // namespace

View File

@ -40,15 +40,15 @@ class ChartTranslationOptionCollection
{
friend std::ostream& operator<<(std::ostream&, const ChartTranslationOptionCollection&);
protected:
const InputType &m_source;
const InputType &m_source;
const TranslationSystem* m_system;
std::vector <DecodeGraph*> m_decodeGraphList;
const ChartCellCollection &m_hypoStackColl;
const std::vector<ChartRuleLookupManager*> &m_ruleLookupManagers;
std::vector< std::vector< ChartTranslationOptionList > > m_collection; /*< contains translation options */
std::vector< std::vector< ChartTranslationOptionList > > m_collection; /*< contains translation options */
std::vector<Phrase*> m_unksrcs;
std::list<TargetPhrase*> m_cacheTargetPhrase;
std::list<TargetPhraseCollection*> m_cacheTargetPhraseCollection;
std::list<std::vector<CoveredChartSpan*>* > m_coveredChartSpanCache;
// for adding 1 trans opt in unknown word proc

View File

@ -26,11 +26,9 @@
#include "ChartCellCollection.h"
#include "WordsRange.h"
using namespace std;
using namespace Moses;
namespace Moses
{
#ifdef USE_HYPO_POOL
ObjectPool<ChartTranslationOptionList> ChartTranslationOptionList::s_objectPool("ChartTranslationOptionList", 3000);
#endif
@ -61,48 +59,44 @@ void ChartTranslationOptionList::Add(const TargetPhraseCollection &targetPhraseC
, bool /* adhereTableLimit */
, size_t ruleLimit)
{
TargetPhraseCollection::const_iterator iter;
TargetPhraseCollection::const_iterator iterEnd = targetPhraseCollection.end();
if (targetPhraseCollection.IsEmpty()) {
return;
}
for (iter = targetPhraseCollection.begin(); iter != iterEnd; ++iter) {
const TargetPhrase &targetPhrase = **iter;
if (m_collection.size() < ruleLimit) {
// not yet filled out quota. add everything
ChartTranslationOption *option = new ChartTranslationOption(
targetPhrase, coveredChartSpan, m_range, chartCellColl);
m_collection.push_back(option);
float score = option->GetEstimateOfBestScore();
m_scoreThreshold = (score < m_scoreThreshold) ? score : m_scoreThreshold;
}
else {
// full but not bursting. add if better than worst score
ChartTranslationOption option(targetPhrase, coveredChartSpan, m_range,
chartCellColl);
float score = option.GetEstimateOfBestScore();
if (score > m_scoreThreshold) {
// dynamic allocation deferred until here on the assumption that most
// options will score below the threshold.
m_collection.push_back(new ChartTranslationOption(option));
}
if (m_collection.size() < ruleLimit) {
// not yet filled out quota. add everything
ChartTranslationOption *option = new ChartTranslationOption(
targetPhraseCollection, coveredChartSpan, m_range, chartCellColl);
m_collection.push_back(option);
float score = option->GetEstimateOfBestScore();
m_scoreThreshold = (score < m_scoreThreshold) ? score : m_scoreThreshold;
}
else {
// full but not bursting. add if better than worst score
ChartTranslationOption option(targetPhraseCollection, coveredChartSpan,
m_range, chartCellColl);
float score = option.GetEstimateOfBestScore();
if (score > m_scoreThreshold) {
// dynamic allocation deferred until here on the assumption that most
// options will score below the threshold.
m_collection.push_back(new ChartTranslationOption(option));
}
}
// prune if bursting
if (m_collection.size() > ruleLimit * 2) {
std::nth_element(m_collection.begin()
, m_collection.begin() + ruleLimit
, m_collection.end()
, ChartTranslationOptionOrderer());
// delete the bottom half
for (size_t ind = ruleLimit; ind < m_collection.size(); ++ind) {
// make the best score of bottom half the score threshold
float score = m_collection[ind]->GetEstimateOfBestScore();
m_scoreThreshold = (score > m_scoreThreshold) ? score : m_scoreThreshold;
delete m_collection[ind];
}
m_collection.resize(ruleLimit);
// prune if bursting
if (m_collection.size() > ruleLimit * 2) {
std::nth_element(m_collection.begin()
, m_collection.begin() + ruleLimit
, m_collection.end()
, ChartTranslationOptionOrderer());
// delete the bottom half
for (size_t ind = ruleLimit; ind < m_collection.size(); ++ind) {
// make the best score of bottom half the score threshold
float score = m_collection[ind]->GetEstimateOfBestScore();
m_scoreThreshold = (score > m_scoreThreshold) ? score : m_scoreThreshold;
delete m_collection[ind];
}
m_collection.resize(ruleLimit);
}
}
@ -157,15 +151,4 @@ void ChartTranslationOptionList::Sort()
std::sort(m_collection.begin(), m_collection.end(), ChartTranslationOptionOrderer());
}
std::ostream& operator<<(std::ostream &out, const ChartTranslationOptionList &coll)
{
ChartTranslationOptionList::const_iterator iter;
for (iter = coll.begin() ; iter != coll.end() ; ++iter) {
const ChartTranslationOption &rule = **iter;
out << rule << endl;
}
return out;
}
}

View File

@ -21,6 +21,7 @@
#include "ChartTrellisNode.h"
#include "ChartHypothesis.h"
#include "CoveredChartSpan.h"
#include "ScoreComponentCollection.h"
#include "StaticData.h"
@ -103,8 +104,8 @@ Phrase ChartTrellisNode::GetOutputPhrase() const
const ChartTranslationOption &transOpt = m_hypo->GetTranslationOption();
VERBOSE(3, "Trans Opt:" << transOpt << std::endl);
VERBOSE(3, "Trans Opt:" << transOpt.GetLastCoveredChartSpan() << ": " << m_hypo->GetCurrTargetPhrase().GetTargetLHS() << "->" << m_hypo->GetCurrTargetPhrase() << std::endl);
const Phrase &currTargetPhrase = m_hypo->GetCurrTargetPhrase();
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
m_hypo->GetCurrTargetPhrase().GetAlignmentInfo().GetNonTermIndexMap();

View File

@ -92,8 +92,9 @@ libmoses_la_HEADERS = \
PrefixTreeMap.h \
ReorderingConstraint.h \
ReorderingStack.h \
RuleCube.h \
RuleCubeQueue.h \
RuleCube.h \
RuleCubeItem.h \
RuleCubeQueue.h \
ScoreComponentCollection.h \
ScoreIndexManager.h \
ScoreProducer.h \
@ -245,8 +246,9 @@ libmoses_la_SOURCES = \
PrefixTreeMap.cpp \
ReorderingConstraint.cpp \
ReorderingStack.cpp \
RuleCube.cpp \
RuleCubeQueue.cpp \
RuleCube.cpp \
RuleCubeItem.cpp \
RuleCubeQueue.cpp \
ScoreComponentCollection.cpp \
ScoreIndexManager.cpp \
ScoreProducer.cpp \

View File

@ -124,6 +124,7 @@ Parameter::Parameter()
#endif
AddParam("cube-pruning-pop-limit", "cbp", "How many hypotheses should be popped for each stack. (default = 1000)");
AddParam("cube-pruning-diversity", "cbd", "How many hypotheses should be created for each coverage. (default = 0)");
AddParam("cube-pruning-lazy-scoring", "cbls", "Don't fully score a hypothesis until it is popped");
AddParam("search-algorithm", "Which search algorithm to use. 0=normal stack, 1=cube pruning, 2=cube growing. (default = 0)");
AddParam("constraint", "Location of the file with target sentences to produce constraining the search");
AddParam("use-alignment-info", "Use word-to-word alignment: actually it is only used to output the word-to-word alignment. Word-to-word alignments are taken from the phrase table if any. Default is false.");

View File

@ -46,6 +46,22 @@ void PhraseDictionaryNodeSCFG::Prune(size_t tableLimit)
m_targetPhraseCollection->Prune(true, tableLimit);
}
// Sort the target phrase collection held by this node and, recursively, those
// of the child nodes stored in m_sourceTermMap and m_nonTermMap.
// tableLimit is forwarded unchanged to every child and to
// TargetPhraseCollection::Sort.
void PhraseDictionaryNodeSCFG::Sort(size_t tableLimit)
{
// recursively sort the children reached via terminal edges ...
for (TerminalMap::iterator p = m_sourceTermMap.begin(); p != m_sourceTermMap.end(); ++p) {
p->second.Sort(tableLimit);
}
// ... and the children reached via non-terminal edges
for (NonTerminalMap::iterator p = m_nonTermMap.begin(); p != m_nonTermMap.end(); ++p) {
p->second.Sort(tableLimit);
}
// sort the TargetPhraseCollection in this node, if it has one
// NOTE(review): the meaning of the 'true' argument is not visible here; the
// Prune() counterpart passes the same value -- confirm against
// TargetPhraseCollection::Sort.
if (m_targetPhraseCollection != NULL) {
m_targetPhraseCollection->Sort(true, tableLimit);
}
}
PhraseDictionaryNodeSCFG *PhraseDictionaryNodeSCFG::GetOrCreateChild(const Word &sourceTerm)
{
assert(!sourceTerm.IsNonTerminal());

View File

@ -170,6 +170,7 @@ public:
}
void Prune(size_t tableLimit);
void Sort(size_t tableLimit);
PhraseDictionaryNodeSCFG *GetOrCreateChild(const Word &sourceTerm);
PhraseDictionaryNodeSCFG *GetOrCreateChild(const Word &sourceNonTerm, const Word &targetNonTerm);
const PhraseDictionaryNodeSCFG *GetChild(const Word &sourceTerm) const;

View File

@ -176,7 +176,7 @@ bool PhraseDictionarySCFG::Load(const std::vector<FactorType> &input
// prune each target phrase collection
if (m_tableLimit) {
m_collection.Prune(m_tableLimit);
m_collection.Sort(m_tableLimit);
}
return true;

View File

@ -19,138 +19,87 @@
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include "RuleCube.h"
#include "ChartCell.h"
#include "ChartTranslationOptionCollection.h"
#include "ChartCellCollection.h"
#include "RuleCubeQueue.h"
#include "WordsRange.h"
#include "ChartTranslationOption.h"
#include "Util.h"
#include "ChartTranslationOptionCollection.h"
#include "CoveredChartSpan.h"
#include "RuleCube.h"
#include "RuleCubeQueue.h"
#include "StaticData.h"
#include "Util.h"
#include "WordsRange.h"
#ifdef HAVE_BOOST
#include <boost/functional/hash.hpp>
#endif
using namespace std;
using namespace Moses;
namespace Moses
{
// create a cube for a rule
RuleCube::RuleCube(const ChartTranslationOption &transOpt
, const ChartCellCollection &allChartCells)
:m_transOpt(transOpt)
// initialise the RuleCube by creating the top-left corner item
RuleCube::RuleCube(const ChartTranslationOption &transOpt,
const ChartCellCollection &allChartCells,
ChartManager &manager)
: m_transOpt(transOpt)
{
const CoveredChartSpan *coveredChartSpan = &transOpt.GetLastCoveredChartSpan();
CreateRuleCubeDimension(coveredChartSpan, allChartCells);
CalcScore();
}
// for each non-terminal, create a ordered list of matching hypothesis from the chart
void RuleCube::CreateRuleCubeDimension(const CoveredChartSpan *coveredChartSpan, const ChartCellCollection &allChartCells)
{
// recurse through the linked list of source side non-terminals and terminals
const CoveredChartSpan *prevCoveredChartSpan = coveredChartSpan->GetPrevCoveredChartSpan();
if (prevCoveredChartSpan)
CreateRuleCubeDimension(prevCoveredChartSpan, allChartCells);
// only deal with non-terminals
if (coveredChartSpan->IsNonTerminal())
{
// get the essential information about the non-terminal
const WordsRange &childRange = coveredChartSpan->GetWordsRange(); // span covered by child
const ChartCell &childCell = allChartCells.Get(childRange); // list of all hypos for that span
const Word &nonTerm = coveredChartSpan->GetSourceWord(); // target (sic!) non-terminal label
// there have to be hypothesis with the desired non-terminal
// (otherwise the rule would not be considered)
assert(!childCell.GetSortedHypotheses(nonTerm).empty());
// create a list of hypotheses that match the non-terminal
RuleCubeDimension ruleCubeDimension(0, childCell.GetSortedHypotheses(nonTerm));
// add them to the vector for such lists
m_cube.push_back(ruleCubeDimension);
RuleCubeItem *item = new RuleCubeItem(transOpt, allChartCells);
m_covered.insert(item);
if (StaticData::Instance().GetCubePruningLazyScoring()) {
item->EstimateScore();
} else {
item->CreateHypothesis(transOpt, manager);
}
}
// create the RuleCube from an existing one, differing only in one child hypothesis
RuleCube::RuleCube(const RuleCube &copy, size_t ruleCubeDimensionIncr)
:m_transOpt(copy.m_transOpt)
,m_cube(copy.m_cube)
{
RuleCubeDimension &ruleCubeDimension = m_cube[ruleCubeDimensionIncr];
ruleCubeDimension.IncrementPos();
CalcScore();
m_queue.push(item);
}
// Destructor: hands m_covered (the set of RuleCubeItem pointers recorded by
// the constructor and by CreateNeighbor) to RemoveAllInColl.
// NOTE(review): RemoveAllInColl is defined outside this file; presumably it
// deletes every pointer in the collection -- confirm.
RuleCube::~RuleCube()
{
//RemoveAllInColl(m_cube);
RemoveAllInColl(m_covered);
}
// Remove the item currently at the top of m_queue and return it.
// Before returning, CreateNeighbors() is called for the popped item, which
// may push new items onto m_queue.
// There is no emptiness check here: m_queue.top() is called unconditionally,
// so callers must make sure the cube is not empty (see IsEmpty()).
// The returned pointer is not removed from m_covered.
// NOTE(review): this suggests the cube keeps ownership of the item -- confirm.
RuleCubeItem *RuleCube::Pop(ChartManager &manager)
{
RuleCubeItem *item = m_queue.top();
m_queue.pop();
// expand the cube around the popped item before handing it to the caller
CreateNeighbors(*item, manager);
return item;
}
// create new RuleCube for neighboring principle rules
// (duplicate detection is handled in RuleCubeQueue)
void RuleCube::CreateNeighbors(RuleCubeQueue &queue) const
void RuleCube::CreateNeighbors(const RuleCubeItem &item, ChartManager &manager)
{
// loop over all child hypotheses
for (size_t ind = 0; ind < m_cube.size(); ind++) {
const RuleCubeDimension &ruleCubeDimension = m_cube[ind];
// create neighbor along translation dimension
const TranslationDimension &translationDimension =
item.GetTranslationDimension();
if (translationDimension.HasMoreTranslations()) {
CreateNeighbor(item, -1, manager);
}
if (ruleCubeDimension.HasMoreHypo()) {
RuleCube *newEntry = new RuleCube(*this, ind);
queue.Add(newEntry);
// create neighbors along all hypothesis dimensions
for (size_t i = 0; i < item.GetHypothesisDimensions().size(); ++i) {
const HypothesisDimension &dimension = item.GetHypothesisDimensions()[i];
if (dimension.HasMoreHypo()) {
CreateNeighbor(item, i, manager);
}
}
}
// compute an estimated cost of the principle rule
// (consisting of rule translation scores plus child hypotheses scores)
void RuleCube::CalcScore()
void RuleCube::CreateNeighbor(const RuleCubeItem &item, int dimensionIndex,
ChartManager &manager)
{
m_combinedScore = m_transOpt.GetTargetPhrase().GetFutureScore();
for (size_t ind = 0; ind < m_cube.size(); ind++) {
const RuleCubeDimension &ruleCubeDimension = m_cube[ind];
const ChartHypothesis *hypo = ruleCubeDimension.GetHypothesis();
m_combinedScore += hypo->GetTotalScore();
RuleCubeItem *newItem = new RuleCubeItem(item, dimensionIndex);
std::pair<ItemSet::iterator, bool> result = m_covered.insert(newItem);
if (!result.second) {
delete newItem; // already seen it
} else {
if (StaticData::Instance().GetCubePruningLazyScoring()) {
newItem->EstimateScore();
} else {
newItem->CreateHypothesis(m_transOpt, manager);
}
m_queue.push(newItem);
}
}
bool RuleCube::operator<(const RuleCube &compare) const
{
if (&m_transOpt != &compare.m_transOpt)
return &m_transOpt < &compare.m_transOpt;
bool ret = m_cube < compare.m_cube;
return ret;
}
#ifdef HAVE_BOOST
std::size_t hash_value(const RuleCubeDimension & ruleCubeDimension)
{
boost::hash<const ChartHypothesis*> hasher;
return hasher(ruleCubeDimension.GetHypothesis());
}
#endif
std::ostream& operator<<(std::ostream &out, const RuleCubeDimension &ruleCubeDimension)
{
out << *ruleCubeDimension.GetHypothesis();
return out;
}
std::ostream& operator<<(std::ostream &out, const RuleCube &ruleCube)
{
out << ruleCube.GetTranslationOption() << endl;
std::vector<RuleCubeDimension>::const_iterator iter;
for (iter = ruleCube.GetCube().begin(); iter != ruleCube.GetCube().end(); ++iter) {
out << *iter << endl;
}
return out;
}
}

View File

@ -21,105 +21,119 @@
#pragma once
#include <vector>
#include <map>
#if HAVE_CONFIG_H
#include "config.h"
#endif
#include "RuleCubeItem.h"
#ifdef HAVE_BOOST
#include <boost/functional/hash.hpp>
#include <boost/unordered_set.hpp>
#include <boost/version.hpp>
#endif
#include <cassert>
#include <queue>
#include <set>
#include <iostream>
#include "WordsRange.h"
#include "Word.h"
#include "ChartHypothesis.h"
#include <vector>
namespace Moses
{
class CoveredChartSpan;
class ChartTranslationOption;
extern bool g_debug;
class TranslationOptionCollection;
class TranslationOptionList;
class ChartCell;
class ChartCellCollection;
class RuleCube;
class RuleCubeQueue;
class ChartManager;
class ChartTranslationOption;
typedef std::vector<const ChartHypothesis*> HypoList;
// wrapper around list of hypothese for a particular non-term of a trans opt
class RuleCubeDimension
// Define an ordering between RuleCubeItems based on their scores. This
// is used to order items in the cube's priority queue.
class RuleCubeItemScoreOrderer
{
friend std::ostream& operator<<(std::ostream&, const RuleCubeDimension&);
protected:
size_t m_pos;
const HypoList *m_orderedHypos;
public:
RuleCubeDimension(size_t pos, const HypoList &orderedHypos)
:m_pos(pos)
,m_orderedHypos(&orderedHypos)
{}
size_t IncrementPos() {
return m_pos++;
}
bool HasMoreHypo() const {
return m_pos + 1 < m_orderedHypos->size();
}
const ChartHypothesis *GetHypothesis() const {
return (*m_orderedHypos)[m_pos];
}
//! transitive comparison used for adding objects into FactorCollection
bool operator<(const RuleCubeDimension &compare) const {
return GetHypothesis() < compare.GetHypothesis();
}
bool operator==(const RuleCubeDimension & compare) const {
return GetHypothesis() == compare.GetHypothesis();
public:
bool operator()(const RuleCubeItem *p, const RuleCubeItem *q) const {
return p->GetScore() < q->GetScore();
}
};
// Stores one dimension in the cube
// (all the hypotheses that match one non terminal)
// Define an ordering between RuleCubeItems based on their positions in the
// cube. This is used to record which positions in the cube have been covered
// during search.
// The functor takes pointers but compares the pointed-to items: it
// dereferences both arguments and delegates to RuleCubeItem::operator<, so
// neither pointer may be null. It serves as the comparator of the
// std::set-based ItemSet typedef in RuleCube.
class RuleCubeItemPositionOrderer
{
public:
// Returns true if *p orders before *q according to RuleCubeItem::operator<.
bool operator()(const RuleCubeItem *p, const RuleCubeItem *q) const {
return *p < *q;
}
};
#ifdef HAVE_BOOST
class RuleCubeItemHasher
{
public:
size_t operator()(const RuleCubeItem *p) const {
size_t seed = 0;
boost::hash_combine(seed, p->GetHypothesisDimensions());
boost::hash_combine(seed, p->GetTranslationDimension().GetTargetPhrase());
return seed;
}
};
class RuleCubeItemEqualityPred
{
public:
bool operator()(const RuleCubeItem *p, const RuleCubeItem *q) const {
return p->GetHypothesisDimensions() == q->GetHypothesisDimensions() &&
p->GetTranslationDimension() == q->GetTranslationDimension();
}
};
#endif
class RuleCube
{
friend std::ostream& operator<<(std::ostream&, const RuleCube&);
protected:
const ChartTranslationOption &m_transOpt;
std::vector<RuleCubeDimension> m_cube;
public:
RuleCube(const ChartTranslationOption &, const ChartCellCollection &,
ChartManager &);
float m_combinedScore;
RuleCube(const RuleCube &copy, size_t ruleCubeDimensionIncr);
void CreateRuleCubeDimension(const CoveredChartSpan *coveredChartSpan, const ChartCellCollection &allChartCells);
void CalcScore();
public:
RuleCube(const ChartTranslationOption &transOpt
, const ChartCellCollection &allChartCells);
~RuleCube();
float GetTopScore() const {
assert(!m_queue.empty());
RuleCubeItem *item = m_queue.top();
return item->GetScore();
}
RuleCubeItem *Pop(ChartManager &);
bool IsEmpty() const { return m_queue.empty(); }
const ChartTranslationOption &GetTranslationOption() const {
return m_transOpt;
}
const std::vector<RuleCubeDimension> &GetCube() const {
return m_cube;
}
float GetCombinedScore() const {
return m_combinedScore;
}
void CreateNeighbors(RuleCubeQueue &) const;
bool operator<(const RuleCube &compare) const;
};
#ifdef HAVE_BOOST
std::size_t hash_value(const RuleCubeDimension &);
private:
#if defined(BOOST_VERSION) && (BOOST_VERSION >= 104200)
typedef boost::unordered_set<RuleCubeItem*,
RuleCubeItemHasher,
RuleCubeItemEqualityPred
> ItemSet;
#else
typedef std::set<RuleCubeItem*, RuleCubeItemPositionOrderer> ItemSet;
#endif
typedef std::priority_queue<RuleCubeItem*,
std::vector<RuleCubeItem*>,
RuleCubeItemScoreOrderer
> Queue;
RuleCube(const RuleCube &); // Not implemented
RuleCube &operator=(const RuleCube &); // Not implemented
void CreateNeighbors(const RuleCubeItem &, ChartManager &);
void CreateNeighbor(const RuleCubeItem &, int, ChartManager &);
const ChartTranslationOption &m_transOpt;
ItemSet m_covered;
Queue m_queue;
};
}

141
moses/src/RuleCubeItem.cpp Normal file
View File

@ -0,0 +1,141 @@
/***********************************************************************
Moses - statistical machine translation system
Copyright (C) 2006-2011 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include "ChartCell.h"
#include "ChartCellCollection.h"
#include "ChartTranslationOption.h"
#include "ChartTranslationOptionCollection.h"
#include "CoveredChartSpan.h"
#include "RuleCubeItem.h"
#include "RuleCubeQueue.h"
#include "WordsRange.h"
#include "Util.h"
#ifdef HAVE_BOOST
#include <boost/functional/hash.hpp>
#endif
namespace Moses
{
#ifdef HAVE_BOOST
std::size_t hash_value(const HypothesisDimension &dimension)
{
boost::hash<const ChartHypothesis*> hasher;
return hasher(dimension.GetHypothesis());
}
#endif
// Construct the item at the corner of the cube: position 0 in the translation
// dimension and position 0 in every hypothesis dimension.
//
// transOpt       supplies the target phrase collection for the translation
//                dimension and the last covered chart span, from which the
//                hypothesis dimensions are derived.
// allChartCells  chart cells from which the hypothesis lists are looked up.
//
// No hypothesis is created here (m_hypothesis is null).  m_score starts at 0
// so that GetScore() never reads an indeterminate value; the real score is
// stored by EstimateScore() or CreateHypothesis().
RuleCubeItem::RuleCubeItem(const ChartTranslationOption &transOpt,
                           const ChartCellCollection &allChartCells)
  : m_translationDimension(0,
                           transOpt.GetTargetPhraseCollection().GetCollection())
  , m_hypothesis(0)
  , m_score(0.0f)
{
  const CoveredChartSpan *lastCCS = &transOpt.GetLastCoveredChartSpan();
  CreateHypothesisDimensions(lastCCS, allChartCells);
}
// Construct a neighbour of an existing item: a copy of its position in the
// cube, advanced by one step along exactly one dimension.
//
// copy               item whose translation and hypothesis dimensions are
//                    copied.
// hypoDimensionIncr  -1 advances the translation dimension; any other value
//                    is the index of the hypothesis dimension to advance and
//                    must be a valid index into the copied vector.
//
// The new item owns no hypothesis (m_hypothesis is null) and the score of
// `copy` is deliberately not carried over: m_score starts at 0 so that
// GetScore() never reads an indeterminate value, and is set properly by
// EstimateScore() or CreateHypothesis().  This constructor does not check that
// the advanced dimension actually has a further element.
RuleCubeItem::RuleCubeItem(const RuleCubeItem &copy, int hypoDimensionIncr)
  : m_translationDimension(copy.m_translationDimension)
  , m_hypothesisDimensions(copy.m_hypothesisDimensions)
  , m_hypothesis(0)
  , m_score(0.0f)
{
  if (hypoDimensionIncr == -1) {
    m_translationDimension.IncrementPos();
  } else {
    // Guard the vector access below against out-of-range indices.
    assert(hypoDimensionIncr >= 0 &&
           static_cast<size_t>(hypoDimensionIncr) < m_hypothesisDimensions.size());
    HypothesisDimension &dimension = m_hypothesisDimensions[hypoDimensionIncr];
    dimension.IncrementPos();
  }
}
// Deletes the hypothesis still held by this item.  After ReleaseHypothesis()
// has handed the hypothesis out, m_hypothesis is null and the delete is a
// no-op.
RuleCubeItem::~RuleCubeItem()
{
delete m_hypothesis;
}
void RuleCubeItem::EstimateScore()
{
m_score = m_translationDimension.GetTargetPhrase()->GetFutureScore();
std::vector<HypothesisDimension>::const_iterator p;
for (p = m_hypothesisDimensions.begin();
p != m_hypothesisDimensions.end(); ++p) {
m_score += p->GetHypothesis()->GetTotalScore();
}
}
void RuleCubeItem::CreateHypothesis(const ChartTranslationOption &transOpt,
ChartManager &manager)
{
m_hypothesis = new ChartHypothesis(transOpt, *this, manager);
m_hypothesis->CalcScore();
m_score = m_hypothesis->GetTotalScore();
}
// Hand the hypothesis held by this item over to the caller.  The item's own
// pointer is reset to null, so the destructor will not delete the returned
// object.  A hypothesis must currently be held (asserted).
ChartHypothesis *RuleCubeItem::ReleaseHypothesis()
{
  assert(m_hypothesis);
  ChartHypothesis *const released = m_hypothesis;
  m_hypothesis = 0;
  return released;
}
// for each non-terminal, create a ordered list of matching hypothesis from the
// chart
void RuleCubeItem::CreateHypothesisDimensions(
const CoveredChartSpan *coveredChartSpan,
const ChartCellCollection &allChartCells)
{
// recurse through the linked list of source side non-terminals and terminals
const CoveredChartSpan *prev = coveredChartSpan->GetPrevCoveredChartSpan();
if (prev) {
CreateHypothesisDimensions(prev, allChartCells);
}
// only deal with non-terminals
if (coveredChartSpan->IsNonTerminal()) {
// get the essential information about the non-terminal:
// span covered by child
const WordsRange &childRange = coveredChartSpan->GetWordsRange();
// list of all hypos for that span
const ChartCell &childCell = allChartCells.Get(childRange);
// target (sic!) non-terminal label
const Word &nonTerm = coveredChartSpan->GetSourceWord();
// there have to be hypothesis with the desired non-terminal
// (otherwise the rule would not be considered)
assert(!childCell.GetSortedHypotheses(nonTerm).empty());
// create a list of hypotheses that match the non-terminal
HypothesisDimension dimension(0, childCell.GetSortedHypotheses(nonTerm));
// add them to the vector for such lists
m_hypothesisDimensions.push_back(dimension);
}
}
// Order items by their position in the cube: the translation dimension is the
// primary key and the vector of hypothesis dimensions breaks ties.
bool RuleCubeItem::operator<(const RuleCubeItem &compare) const
{
  if (!(m_translationDimension == compare.m_translationDimension)) {
    return m_translationDimension < compare.m_translationDimension;
  }
  return m_hypothesisDimensions < compare.m_hypothesisDimensions;
}
}

148
moses/src/RuleCubeItem.h Normal file
View File

@ -0,0 +1,148 @@
/***********************************************************************
Moses - statistical machine translation system
Copyright (C) 2006-2011 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#pragma once
#if HAVE_CONFIG_H
#include "config.h"
#endif
#include <vector>
namespace Moses
{
class ChartCellCollection;
class ChartHypothesis;
class ChartManager;
class ChartTranslationOption;
class CoveredChartSpan;
class TargetPhrase;
typedef std::vector<const ChartHypothesis*> HypoList;
// wrapper around list of target phrase translation options
class TranslationDimension
{
public:
TranslationDimension(size_t pos,
const std::vector<TargetPhrase*> &orderedTargetPhrases)
: m_pos(pos)
, m_orderedTargetPhrases(&orderedTargetPhrases)
{}
size_t IncrementPos() { return m_pos++; }
bool HasMoreTranslations() const {
return m_pos+1 < m_orderedTargetPhrases->size();
}
const TargetPhrase *GetTargetPhrase() const {
return (*m_orderedTargetPhrases)[m_pos];
}
bool operator<(const TranslationDimension &compare) const {
return GetTargetPhrase() < compare.GetTargetPhrase();
}
bool operator==(const TranslationDimension &compare) const {
return GetTargetPhrase() == compare.GetTargetPhrase();
}
private:
size_t m_pos;
const std::vector<TargetPhrase*> *m_orderedTargetPhrases;
};
// wrapper around list of hypotheses for a particular non-term of a trans opt
class HypothesisDimension
{
public:
HypothesisDimension(size_t pos, const HypoList &orderedHypos)
: m_pos(pos)
, m_orderedHypos(&orderedHypos)
{}
size_t IncrementPos() { return m_pos++; }
bool HasMoreHypo() const {
return m_pos+1 < m_orderedHypos->size();
}
const ChartHypothesis *GetHypothesis() const {
return (*m_orderedHypos)[m_pos];
}
bool operator<(const HypothesisDimension &compare) const {
return GetHypothesis() < compare.GetHypothesis();
}
bool operator==(const HypothesisDimension &compare) const {
return GetHypothesis() == compare.GetHypothesis();
}
private:
size_t m_pos;
const HypoList *m_orderedHypos;
};
#ifdef HAVE_BOOST
std::size_t hash_value(const HypothesisDimension &);
#endif
// One position in a rule cube: a choice of target phrase (translation
// dimension) together with a choice of hypothesis for each non-terminal
// (hypothesis dimensions).  An item carries a score and may own the
// ChartHypothesis built for its position.
class RuleCubeItem
{
public:
// Item at position 0 of every dimension for the given translation option.
RuleCubeItem(const ChartTranslationOption &, const ChartCellCollection &);
// Copy of an item's position advanced by one along a single dimension; the
// int is -1 for the translation dimension, otherwise the index of the
// hypothesis dimension to advance.
RuleCubeItem(const RuleCubeItem &, int);
// Deletes the hypothesis if this item still owns one.
~RuleCubeItem();
const TranslationDimension &GetTranslationDimension() const {
return m_translationDimension;
}
const std::vector<HypothesisDimension> &GetHypothesisDimensions() const {
return m_hypothesisDimensions;
}
// Score most recently stored by EstimateScore() or CreateHypothesis().
float GetScore() const { return m_score; }
// Stores the target phrase's future score plus the total scores of the
// selected hypotheses; does not create a hypothesis.
void EstimateScore();
// Creates and scores the hypothesis for this position and stores its total
// score.
void CreateHypothesis(const ChartTranslationOption &, ChartManager &);
// Transfers ownership of the hypothesis to the caller; one must be held.
ChartHypothesis *ReleaseHypothesis();
// Orders items by position: translation dimension first, then hypothesis
// dimensions.
bool operator<(const RuleCubeItem &) const;
private:
// Plain copying is disabled: both members are declared but not defined.
RuleCubeItem(const RuleCubeItem &); // Not implemented
RuleCubeItem &operator=(const RuleCubeItem &); // Not implemented
// Fills m_hypothesisDimensions, one entry per non-terminal in the chain of
// covered chart spans.
void CreateHypothesisDimensions(const CoveredChartSpan *,
const ChartCellCollection &);
TranslationDimension m_translationDimension;
std::vector<HypothesisDimension> m_hypothesisDimensions;
// Owned while non-null; null before creation and after release.
ChartHypothesis *m_hypothesis;
float m_score;
};
}

View File

@ -21,42 +21,48 @@
#include "RuleCubeQueue.h"
#include "Util.h"
using namespace std;
#include "RuleCubeItem.h"
#include "StaticData.h"
namespace Moses
{
RuleCubeQueue::~RuleCubeQueue()
{
RemoveAllInColl(m_uniqueEntry);
while (!m_queue.empty()) {
RuleCube *cube = m_queue.top();
m_queue.pop();
delete cube;
}
}
bool RuleCubeQueue::Add(RuleCube *ruleCube)
void RuleCubeQueue::Add(RuleCube *ruleCube)
{
pair<UniqueCubeEntry::iterator, bool> inserted = m_uniqueEntry.insert(ruleCube);
m_queue.push(ruleCube);
}
if (inserted.second) {
// inserted
m_sortedByScore.push(ruleCube);
ChartHypothesis *RuleCubeQueue::Pop()
{
// pop the most promising rule cube
RuleCube *cube = m_queue.top();
m_queue.pop();
// pop the most promising item from the cube and get the corresponding
// hypothesis
RuleCubeItem *item = cube->Pop(m_manager);
if (StaticData::Instance().GetCubePruningLazyScoring()) {
item->CreateHypothesis(cube->GetTranslationOption(), m_manager);
}
ChartHypothesis *hypo = item->ReleaseHypothesis();
// if the cube contains more items then push it back onto the queue
if (!cube->IsEmpty()) {
m_queue.push(cube);
} else {
// already there
//cerr << "already there\n";
delete ruleCube;
delete cube;
}
//assert(m_uniqueEntry.size() == m_sortedByScore.size());
return inserted.second;
}
RuleCube *RuleCubeQueue::Pop()
{
RuleCube *ruleCube = m_sortedByScore.top();
m_sortedByScore.pop();
return ruleCube;
return hypo;
}
}

View File

@ -18,85 +18,49 @@
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <queue>
#include <vector>
#include <set>
#pragma once
#if HAVE_CONFIG_H
#include "config.h"
#endif
#include "RuleCube.h"
#ifdef HAVE_BOOST
#include <boost/functional/hash.hpp>
#include <boost/unordered_set.hpp>
#include <boost/version.hpp>
#endif
#include <queue>
#include <vector>
namespace Moses
{
#ifdef HAVE_BOOST
class RuleCubeUniqueHasher
class ChartManager;
// Define an ordering between RuleCube based on their best item scores. This
// is used to order items in the priority queue.
class RuleCubeOrderer
{
public:
size_t operator()(const RuleCube * p) const {
size_t seed = 0;
boost::hash_combine(seed, &(p->GetTranslationOption()));
boost::hash_combine(seed, p->GetCube());
return seed;
public:
bool operator()(const RuleCube *p, const RuleCube *q) const {
return p->GetTopScore() < q->GetTopScore();
}
};
class RuleCubeUniqueEqualityPred
{
public:
bool operator()(const RuleCube * p, const RuleCube * q) const {
return ((&(p->GetTranslationOption()) == &(q->GetTranslationOption()))
&& (p->GetCube() == q->GetCube()));
}
};
#endif
class RuleCubeUniqueOrderer
{
public:
bool operator()(const RuleCube* entryA, const RuleCube* entryB) const {
return (*entryA) < (*entryB);
}
};
class RuleCubeScoreOrderer
{
public:
bool operator()(const RuleCube* entryA, const RuleCube* entryB) const {
return (entryA->GetCombinedScore() < entryB->GetCombinedScore());
}
};
class RuleCubeQueue
{
protected:
#if defined(BOOST_VERSION) && (BOOST_VERSION >= 104200)
typedef boost::unordered_set<RuleCube*,
RuleCubeUniqueHasher,
RuleCubeUniqueEqualityPred> UniqueCubeEntry;
#else
typedef std::set<RuleCube*, RuleCubeUniqueOrderer> UniqueCubeEntry;
#endif
UniqueCubeEntry m_uniqueEntry;
typedef std::priority_queue<RuleCube*, std::vector<RuleCube*>, RuleCubeScoreOrderer> SortedByScore;
SortedByScore m_sortedByScore;
public:
public:
RuleCubeQueue(ChartManager &manager) : m_manager(manager) {}
~RuleCubeQueue();
bool IsEmpty() const {
return m_sortedByScore.empty();
}
RuleCube *Pop();
bool Add(RuleCube *ruleCube);
};
void Add(RuleCube *);
ChartHypothesis *Pop();
bool IsEmpty() const { return m_queue.empty(); }
private:
typedef std::priority_queue<RuleCube*, std::vector<RuleCube*>,
RuleCubeOrderer > Queue;
Queue m_queue;
ChartManager &m_manager;
};
}

View File

@ -333,6 +333,8 @@ bool StaticData::LoadData(Parameter *parameter)
m_cubePruningDiversity = (m_parameter->GetParam("cube-pruning-diversity").size() > 0)
? Scan<size_t>(m_parameter->GetParam("cube-pruning-diversity")[0]) : DEFAULT_CUBE_PRUNING_DIVERSITY;
SetBooleanParameter(&m_cubePruningLazyScoring, "cube-pruning-lazy-scoring", false);
// unknown word processing
SetBooleanParameter( &m_dropUnknown, "drop-unknown", false );

View File

@ -195,6 +195,7 @@ protected:
size_t m_cubePruningPopLimit;
size_t m_cubePruningDiversity;
bool m_cubePruningLazyScoring;
size_t m_ruleLimit;
@ -313,6 +314,9 @@ public:
// Value of the "cube-pruning-diversity" parameter as stored by LoadData()
// (DEFAULT_CUBE_PRUNING_DIVERSITY when the parameter is not given).
size_t GetCubePruningDiversity() const {
return m_cubePruningDiversity;
}
// Value of the "cube-pruning-lazy-scoring" switch as stored by LoadData().
// When true, RuleCubeQueue::Pop() creates the hypothesis for an item only
// after the item has been popped.
bool GetCubePruningLazyScoring() const {
return m_cubePruningLazyScoring;
}
// Returns m_recoverPath.
// NOTE(review): the name reads like a boolean query but the return type is
// size_t; the declaration of m_recoverPath is not visible here -- confirm.
size_t IsPathRecoveryEnabled() const {
return m_recoverPath;
}

View File

@ -55,6 +55,25 @@ void TargetPhraseCollection::Prune(bool adhereTableLimit, size_t tableLimit)
}
}
void TargetPhraseCollection::Sort(bool adhereTableLimit, size_t tableLimit)
{
std::vector<TargetPhrase*>::iterator iterMiddle;
iterMiddle = (tableLimit == 0 || m_collection.size() < tableLimit)
? m_collection.end()
: m_collection.begin()+tableLimit;
std::partial_sort(m_collection.begin(), iterMiddle, m_collection.end(),
CompareTargetPhrase());
if (adhereTableLimit && m_collection.size() > tableLimit) {
for (size_t i = tableLimit; i < m_collection.size(); ++i) {
TargetPhrase *targetPhrase = m_collection[i];
delete targetPhrase;
}
m_collection.erase(m_collection.begin()+tableLimit, m_collection.end());
}
}
}

View File

@ -57,6 +57,8 @@ public:
RemoveAllInColl(m_collection);
}
//! read-only access to the underlying vector of target phrase pointers
const std::vector<TargetPhrase*> &GetCollection() const { return m_collection; }
//! divide collection into 2 buckets using std::nth_element, the top & bottom according to table limit
void NthElement(size_t tableLimit);
@ -74,6 +76,7 @@ public:
}
void Prune(bool adhereTableLimit, size_t tableLimit);
void Sort(bool adhereTableLimit, size_t tableLimit);
};