diff --git a/OnDiskPt/src/TargetPhraseCollection.cpp b/OnDiskPt/src/TargetPhraseCollection.cpp index c6364cafe..910af9ea9 100644 --- a/OnDiskPt/src/TargetPhraseCollection.cpp +++ b/OnDiskPt/src/TargetPhraseCollection.cpp @@ -144,7 +144,7 @@ Moses::TargetPhraseCollection *TargetPhraseCollection::ConvertToMoses(const std: ret->Add(mosesPhrase); } - ret->Prune(true, phraseDict.GetTableLimit()); + ret->Sort(true, phraseDict.GetTableLimit()); return ret; diff --git a/moses-chart-cmd/src/IOWrapper.cpp b/moses-chart-cmd/src/IOWrapper.cpp index 39498104b..31ec4c4ee 100644 --- a/moses-chart-cmd/src/IOWrapper.cpp +++ b/moses-chart-cmd/src/IOWrapper.cpp @@ -43,6 +43,10 @@ POSSIBILITY OF SUCH DAMAGE. #include "PhraseDictionary.h" #include "ChartTrellisPathList.h" #include "ChartTrellisPath.h" +#include "ChartTranslationOption.h" +#include "ChartHypothesis.h" +#include "CoveredChartSpan.h" + using namespace std; using namespace Moses; @@ -223,7 +227,11 @@ void OutputTranslationOptions(std::ostream &out, const ChartHypothesis *hypo, lo { // recursive if (hypo != NULL) { - out << "Trans Opt " << translationId << " " << hypo->GetCurrSourceRange() << ": " << hypo->GetTranslationOption() + out << "Trans Opt " << translationId + << " " << hypo->GetCurrSourceRange() + << ": " << hypo->GetTranslationOption().GetLastCoveredChartSpan() + << ": " << hypo->GetCurrTargetPhrase().GetTargetLHS() + << "->" << hypo->GetCurrTargetPhrase() << " " << hypo->GetTotalScore() << hypo->GetScoreBreakdown() << endl; } diff --git a/moses/src/ChartCell.cpp b/moses/src/ChartCell.cpp index 521bb7ec5..3a05f204a 100644 --- a/moses/src/ChartCell.cpp +++ b/moses/src/ChartCell.cpp @@ -32,7 +32,6 @@ #include "ChartTranslationOptionList.h" using namespace std; -using namespace Moses; namespace Moses { @@ -83,14 +82,14 @@ void ChartCell::ProcessSentence(const ChartTranslationOptionList &transOptList const StaticData &staticData = StaticData::Instance(); // priority queue for applicable rules with selected hypotheses - RuleCubeQueue queue; + RuleCubeQueue queue(m_manager); // add all trans opt into queue. using only 1st child node. ChartTranslationOptionList::const_iterator iterList; for (iterList = transOptList.begin(); iterList != transOptList.end(); ++iterList) { const ChartTranslationOption &transOpt = **iterList; - RuleCube *ruleCube = new RuleCube(transOpt, allChartCells); + RuleCube *ruleCube = new RuleCube(transOpt, allChartCells, m_manager); queue.Add(ruleCube); } @@ -98,16 +97,8 @@ void ChartCell::ProcessSentence(const ChartTranslationOptionList &transOptList const size_t popLimit = staticData.GetCubePruningPopLimit(); for (size_t numPops = 0; numPops < popLimit && !queue.IsEmpty(); ++numPops) { - RuleCube *ruleCube = queue.Pop(); - - // create hypothesis from RuleCube - ChartHypothesis *hypo = new ChartHypothesis(*ruleCube, m_manager); - assert(hypo); - hypo->CalcScore(); + ChartHypothesis *hypo = queue.Pop(); AddHypothesis(hypo); - - // add neighbors to the queue - ruleCube->CreateNeighbors(queue); } } @@ -223,5 +214,3 @@ std::ostream& operator<<(std::ostream &out, const ChartCell &cell) } } // namespace - - diff --git a/moses/src/ChartHypothesis.cpp b/moses/src/ChartHypothesis.cpp index fc13629ea..ee701576b 100644 --- a/moses/src/ChartHypothesis.cpp +++ b/moses/src/ChartHypothesis.cpp @@ -22,7 +22,7 @@ #include #include #include "ChartHypothesis.h" -#include "RuleCube.h" +#include "RuleCubeItem.h" #include "ChartCell.h" #include "ChartManager.h" #include "TargetPhrase.h" @@ -33,9 +33,6 @@ #include "ChartTranslationOption.h" #include "FFState.h" -using namespace std; -using namespace Moses; - namespace Moses { unsigned int ChartHypothesis::s_HypothesesCreated = 0; @@ -45,32 +42,33 @@ ObjectPool ChartHypothesis::s_objectPool("ChartHypothesis", 300 #endif /** Create a hypothesis from a rule */ -ChartHypothesis::ChartHypothesis(const RuleCube &ruleCube, ChartManager &manager) - :m_transOpt(ruleCube.GetTranslationOption()) - ,m_id(++s_HypothesesCreated) - ,m_currSourceWordsRange(ruleCube.GetTranslationOption().GetSourceWordsRange()) - ,m_ffStates(manager.GetTranslationSystem()->GetStatefulFeatureFunctions().size()) +ChartHypothesis::ChartHypothesis(const ChartTranslationOption &transOpt, + const RuleCubeItem &item, + ChartManager &manager) + :m_id(++s_HypothesesCreated) + ,m_targetPhrase(*(item.GetTranslationDimension().GetTargetPhrase())) + ,m_transOpt(transOpt) ,m_contextPrefix(Output, manager.GetTranslationSystem()->GetLanguageModels().GetMaxNGramOrder()) ,m_contextSuffix(Output, manager.GetTranslationSystem()->GetLanguageModels().GetMaxNGramOrder()) + ,m_currSourceWordsRange(transOpt.GetSourceWordsRange()) + ,m_ffStates(manager.GetTranslationSystem()->GetStatefulFeatureFunctions().size()) ,m_arcList(NULL) ,m_winningHypo(NULL) ,m_manager(manager) { - //TRACE_ERR(m_targetPhrase << endl); - // underlying hypotheses for sub-spans m_numTargetTerminals = GetCurrTargetPhrase().GetNumTerminals(); - const std::vector &childEntries = ruleCube.GetCube(); + const std::vector &childEntries = item.GetHypothesisDimensions(); // ... are stored assert(m_prevHypos.empty()); m_prevHypos.reserve(childEntries.size()); - vector::const_iterator iter; + std::vector::const_iterator iter; for (iter = childEntries.begin(); iter != childEntries.end(); ++iter) { - const RuleCubeDimension &ruleCubeDimension = *iter; - const ChartHypothesis *prevHypo = ruleCubeDimension.GetHypothesis(); + const HypothesisDimension &dimension = *iter; + const ChartHypothesis *prevHypo = dimension.GetHypothesis(); // keep count of words (= length of generated string) m_numTargetTerminals += prevHypo->GetNumTargetTerminals(); @@ -179,8 +177,8 @@ size_t ChartHypothesis::CalcSuffix(Phrase &ret, size_t size) const // special handling for small hypotheses // does the prefix match the entire hypothesis string? -> just copy prefix if (m_contextPrefix.GetSize() == m_numTargetTerminals) { - size_t maxCount = min(m_contextPrefix.GetSize(), size) - , pos = m_contextPrefix.GetSize() - 1; + size_t maxCount = std::min(m_contextPrefix.GetSize(), size); + size_t pos= m_contextPrefix.GetSize() - 1; for (size_t ind = 0; ind < maxCount; ++ind) { const Word &word = m_contextPrefix.GetWord(pos); @@ -267,7 +265,7 @@ void ChartHypothesis::CalcScore() // sfs[i]->ChartEvaluate(m_targetPhrase, &m_scoreBreakdown); //} - const vector& ffs = + const std::vector& ffs = m_manager.GetTranslationSystem()->GetStatefulFeatureFunctions(); for (unsigned i = 0; i < ffs.size(); ++i) { m_ffStates[i] = ffs[i]->EvaluateChart(*this,i,&m_scoreBreakdown); @@ -361,7 +359,7 @@ void ChartHypothesis::SetWinningHypo(const ChartHypothesis *hypo) TO_STRING_BODY(ChartHypothesis) // friend -ostream& operator<<(ostream& out, const ChartHypothesis& hypo) +std::ostream& operator<<(std::ostream& out, const ChartHypothesis& hypo) { out << hypo.GetId(); @@ -392,4 +390,3 @@ ostream& operator<<(ostream& out, const ChartHypothesis& hypo) } } - diff --git a/moses/src/ChartHypothesis.h b/moses/src/ChartHypothesis.h index 5d76a67d6..983213752 100644 --- a/moses/src/ChartHypothesis.h +++ b/moses/src/ChartHypothesis.h @@ -31,9 +31,10 @@ namespace Moses { -class RuleCube; + class ChartHypothesis; class ChartManager; +class RuleCubeItem; typedef std::vector ChartArcList; @@ -50,6 +51,7 @@ protected: static unsigned int s_HypothesesCreated; int m_id; /**< numeric ID of this hypothesis, used for logging */ + const TargetPhrase &m_targetPhrase; const ChartTranslationOption &m_transOpt; Phrase m_contextPrefix, m_contextSuffix; @@ -97,7 +99,9 @@ public: } #endif - explicit ChartHypothesis(const RuleCube &ruleCube, ChartManager &manager); + ChartHypothesis(const ChartTranslationOption &, const RuleCubeItem &item, + ChartManager &manager); + ~ChartHypothesis(); int GetId()const { @@ -107,7 +111,7 @@ public: return m_transOpt; } const TargetPhrase &GetCurrTargetPhrase()const { - return m_transOpt.GetTargetPhrase(); + return m_targetPhrase; } const WordsRange &GetCurrSourceRange()const { return m_currSourceWordsRange; diff --git a/moses/src/ChartRuleLookupManagerOnDisk.cpp b/moses/src/ChartRuleLookupManagerOnDisk.cpp index 7498771a3..9e34c2675 100644 --- a/moses/src/ChartRuleLookupManagerOnDisk.cpp +++ b/moses/src/ChartRuleLookupManagerOnDisk.cpp @@ -268,8 +268,10 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection( } assert(targetPhraseCollection); - outColl.Add(*targetPhraseCollection, *coveredChartSpan, - GetCellCollection(), adhereTableLimit, rulesLimit); + if (!targetPhraseCollection->IsEmpty()) { + outColl.Add(*targetPhraseCollection, *coveredChartSpan, + GetCellCollection(), adhereTableLimit, rulesLimit); + } } // if (node) diff --git a/moses/src/ChartTranslationOption.cpp b/moses/src/ChartTranslationOption.cpp index ab886034a..9b228f419 100644 --- a/moses/src/ChartTranslationOption.cpp +++ b/moses/src/ChartTranslationOption.cpp @@ -19,12 +19,12 @@ ***********************************************************************/ #include "ChartTranslationOption.h" -#include "TargetPhrase.h" -#include "AlignmentInfo.h" -#include "CoveredChartSpan.h" -#include "ChartCellCollection.h" -using namespace std; +#include "AlignmentInfo.h" +#include "ChartCellCollection.h" +#include "CoveredChartSpan.h" + +#include namespace Moses { @@ -54,18 +54,11 @@ void ChartTranslationOption::CalcEstimateOfBestScore( assert(!childCell.GetSortedHypotheses(nonTerm).empty()); // create a list of hypotheses that match the non-terminal - const vector &stack = + const std::vector &stack = childCell.GetSortedHypotheses(nonTerm); const ChartHypothesis *hypo = stack[0]; m_estimateOfBestScore += hypo->GetTotalScore(); } } -std::ostream& operator<<(std::ostream &out, const ChartTranslationOption &rule) -{ - out << rule.m_lastCoveredChartSpan << ": " << rule.m_targetPhrase.GetTargetLHS() << "->" << rule.m_targetPhrase; - return out; } - -} // namespace - diff --git a/moses/src/ChartTranslationOption.h b/moses/src/ChartTranslationOption.h index 56c2de74d..a9746877f 100644 --- a/moses/src/ChartTranslationOption.h +++ b/moses/src/ChartTranslationOption.h @@ -20,70 +20,68 @@ #pragma once +#include "TargetPhrase.h" +#include "TargetPhraseCollection.h" +#include "WordsRange.h" + #include #include -#include "Word.h" -#include "WordsRange.h" -#include "TargetPhrase.h" namespace Moses { + class CoveredChartSpan; class ChartCellCollection; -// basically a phrase translation and the vector of words consumed to map each word +// Similar to a DottedRule, but contains a direct reference to a list +// of translations and provdes an estimate of the best score. class ChartTranslationOption { - friend std::ostream& operator<<(std::ostream&, const ChartTranslationOption&); - -protected: - const TargetPhrase &m_targetPhrase; - const CoveredChartSpan &m_lastCoveredChartSpan; - /* map each source word in the phrase table to: - 1. a word in the input sentence, if the pt word is a terminal - 2. a 1+ phrase in the input sentence, if the pt word is a non-terminal - */ - const WordsRange &m_wordsRange; - - float m_estimateOfBestScore; - - ChartTranslationOption &operator=(const ChartTranslationOption &); // not implemented - - void CalcEstimateOfBestScore(const CoveredChartSpan *, const ChartCellCollection &); - -public: - ChartTranslationOption(const TargetPhrase &targetPhrase, const CoveredChartSpan &lastCoveredChartSpan, const WordsRange &wordsRange, const ChartCellCollection &allChartCells) - :m_targetPhrase(targetPhrase) - ,m_lastCoveredChartSpan(lastCoveredChartSpan) - ,m_wordsRange(wordsRange) - ,m_estimateOfBestScore(m_targetPhrase.GetFutureScore()) + public: + ChartTranslationOption(const TargetPhraseCollection &targetPhraseColl, + const CoveredChartSpan &lastCoveredChartSpan, + const WordsRange &wordsRange, + const ChartCellCollection &allChartCells) + : m_lastCoveredChartSpan(lastCoveredChartSpan) + , m_targetPhraseCollection(targetPhraseColl) + , m_wordsRange(wordsRange) + , m_estimateOfBestScore(0) { + const TargetPhrase &targetPhrase = **(m_targetPhraseCollection.begin()); + m_estimateOfBestScore = targetPhrase.GetFutureScore(); CalcEstimateOfBestScore(&m_lastCoveredChartSpan, allChartCells); } - ~ChartTranslationOption() - {} - - const TargetPhrase &GetTargetPhrase() const { - return m_targetPhrase; - } + ~ChartTranslationOption() {} const CoveredChartSpan &GetLastCoveredChartSpan() const { return m_lastCoveredChartSpan; } + const TargetPhraseCollection &GetTargetPhraseCollection() const { + return m_targetPhraseCollection; + } + const WordsRange &GetSourceWordsRange() const { return m_wordsRange; } // return an estimate of the best score possible with this translation option. - // the estimate is the sum of the target phrase's estimated score plus the - // scores of the best child hypotheses. (the same as the ordering criterion - // currently used in RuleCubeQueue.) - inline float GetEstimateOfBestScore() const { - return m_estimateOfBestScore; - } + // the estimate is the sum of the top target phrase's estimated score plus the + // scores of the best child hypotheses. + inline float GetEstimateOfBestScore() const { return m_estimateOfBestScore; } + private: + // not implemented + ChartTranslationOption &operator=(const ChartTranslationOption &); + + void CalcEstimateOfBestScore(const CoveredChartSpan *, + const ChartCellCollection &); + + const CoveredChartSpan &m_lastCoveredChartSpan; + const TargetPhraseCollection &m_targetPhraseCollection; + const WordsRange &m_wordsRange; + float m_estimateOfBestScore; }; } diff --git a/moses/src/ChartTranslationOptionCollection.cpp b/moses/src/ChartTranslationOptionCollection.cpp index 806609822..0c1152af2 100644 --- a/moses/src/ChartTranslationOptionCollection.cpp +++ b/moses/src/ChartTranslationOptionCollection.cpp @@ -30,7 +30,6 @@ #include "Util.h" using namespace std; -using namespace Moses; namespace Moses { @@ -43,8 +42,8 @@ ChartTranslationOptionCollection::ChartTranslationOptionCollection(InputType con ,m_system(system) ,m_decodeGraphList(system->GetDecodeGraphs()) ,m_hypoStackColl(hypoStackColl) - ,m_collection(source.GetSize()) ,m_ruleLookupManagers(ruleLookupManagers) + ,m_collection(source.GetSize()) { // create 2-d vector size_t size = source.GetSize(); @@ -59,7 +58,7 @@ ChartTranslationOptionCollection::ChartTranslationOptionCollection(InputType con ChartTranslationOptionCollection::~ChartTranslationOptionCollection() { RemoveAllInColl(m_unksrcs); - RemoveAllInColl(m_cacheTargetPhrase); + RemoveAllInColl(m_cacheTargetPhraseCollection); std::list* >::iterator iterOuter; for (iterOuter = m_coveredChartSpanCache.begin(); iterOuter != m_coveredChartSpanCache.end(); ++iterOuter) { @@ -225,8 +224,10 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW // add to dictionary TargetPhrase *targetPhrase = new TargetPhrase(Output); + TargetPhraseCollection *tpc = new TargetPhraseCollection; + tpc->Add(targetPhrase); - m_cacheTargetPhrase.push_back(targetPhrase); + m_cacheTargetPhraseCollection.push_back(tpc); Word &targetWord = targetPhrase->AddWord(); targetWord.CreateUnknownWord(sourceWord); @@ -240,7 +241,7 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW targetPhrase->SetTargetLHS(targetLHS); // chart rule - ChartTranslationOption *chartRule = new ChartTranslationOption(*targetPhrase + ChartTranslationOption *chartRule = new ChartTranslationOption(*tpc , *coveredChartSpanList->back() , range , m_hypoStackColl); @@ -251,6 +252,8 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW vector unknownScore(1, FloorScore(-numeric_limits::infinity())); TargetPhrase *targetPhrase = new TargetPhrase(Output); + TargetPhraseCollection *tpc = new TargetPhraseCollection; + tpc->Add(targetPhrase); // loop const UnknownLHSList &lhsList = staticData.GetUnknownLHS(); UnknownLHSList::const_iterator iterLHS; @@ -262,7 +265,7 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW targetLHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), targetLHSStr, true); assert(targetLHS.GetFactor(0) != NULL); - m_cacheTargetPhrase.push_back(targetPhrase); + m_cacheTargetPhraseCollection.push_back(tpc); targetPhrase->SetSourcePhrase(m_unksrc); targetPhrase->SetScore(unknownWordPenaltyProducer, unknownScore); targetPhrase->SetTargetLHS(targetLHS); @@ -274,7 +277,7 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW // chart rule assert(coveredChartSpanList->size()); - ChartTranslationOption *chartRule = new ChartTranslationOption(*targetPhrase + ChartTranslationOption *chartRule = new ChartTranslationOption(*tpc , *coveredChartSpanList->back() , range , m_hypoStackColl); @@ -302,7 +305,4 @@ void ChartTranslationOptionCollection::Sort(size_t startPos, size_t endPos) list.Sort(); } - } // namespace - - diff --git a/moses/src/ChartTranslationOptionCollection.h b/moses/src/ChartTranslationOptionCollection.h index 20790e1c2..cf47ea561 100644 --- a/moses/src/ChartTranslationOptionCollection.h +++ b/moses/src/ChartTranslationOptionCollection.h @@ -40,15 +40,15 @@ class ChartTranslationOptionCollection { friend std::ostream& operator<<(std::ostream&, const ChartTranslationOptionCollection&); protected: - const InputType &m_source; + const InputType &m_source; const TranslationSystem* m_system; std::vector m_decodeGraphList; const ChartCellCollection &m_hypoStackColl; const std::vector &m_ruleLookupManagers; - std::vector< std::vector< ChartTranslationOptionList > > m_collection; /*< contains translation options */ + std::vector< std::vector< ChartTranslationOptionList > > m_collection; /*< contains translation options */ std::vector m_unksrcs; - std::list m_cacheTargetPhrase; + std::list m_cacheTargetPhraseCollection; std::list* > m_coveredChartSpanCache; // for adding 1 trans opt in unknown word proc diff --git a/moses/src/ChartTranslationOptionList.cpp b/moses/src/ChartTranslationOptionList.cpp index f51da76f9..ad81a25cc 100644 --- a/moses/src/ChartTranslationOptionList.cpp +++ b/moses/src/ChartTranslationOptionList.cpp @@ -26,11 +26,9 @@ #include "ChartCellCollection.h" #include "WordsRange.h" -using namespace std; -using namespace Moses; - namespace Moses { + #ifdef USE_HYPO_POOL ObjectPool ChartTranslationOptionList::s_objectPool("ChartTranslationOptionList", 3000); #endif @@ -61,48 +59,44 @@ void ChartTranslationOptionList::Add(const TargetPhraseCollection &targetPhraseC , bool /* adhereTableLimit */ , size_t ruleLimit) { - TargetPhraseCollection::const_iterator iter; - TargetPhraseCollection::const_iterator iterEnd = targetPhraseCollection.end(); + if (targetPhraseCollection.IsEmpty()) { + return; + } - for (iter = targetPhraseCollection.begin(); iter != iterEnd; ++iter) { - const TargetPhrase &targetPhrase = **iter; - - if (m_collection.size() < ruleLimit) { - // not yet filled out quota. add everything - ChartTranslationOption *option = new ChartTranslationOption( - targetPhrase, coveredChartSpan, m_range, chartCellColl); - m_collection.push_back(option); - float score = option->GetEstimateOfBestScore(); - m_scoreThreshold = (score < m_scoreThreshold) ? score : m_scoreThreshold; - } - else { - // full but not bursting. add if better than worst score - ChartTranslationOption option(targetPhrase, coveredChartSpan, m_range, - chartCellColl); - float score = option.GetEstimateOfBestScore(); - if (score > m_scoreThreshold) { - // dynamic allocation deferred until here on the assumption that most - // options will score below the threshold. - m_collection.push_back(new ChartTranslationOption(option)); - } + if (m_collection.size() < ruleLimit) { + // not yet filled out quota. add everything + ChartTranslationOption *option = new ChartTranslationOption( + targetPhraseCollection, coveredChartSpan, m_range, chartCellColl); + m_collection.push_back(option); + float score = option->GetEstimateOfBestScore(); + m_scoreThreshold = (score < m_scoreThreshold) ? score : m_scoreThreshold; + } + else { + // full but not bursting. add if better than worst score + ChartTranslationOption option(targetPhraseCollection, coveredChartSpan, + m_range, chartCellColl); + float score = option.GetEstimateOfBestScore(); + if (score > m_scoreThreshold) { + // dynamic allocation deferred until here on the assumption that most + // options will score below the threshold. + m_collection.push_back(new ChartTranslationOption(option)); } + } - // prune if bursting - if (m_collection.size() > ruleLimit * 2) { - std::nth_element(m_collection.begin() - , m_collection.begin() + ruleLimit - , m_collection.end() - , ChartTranslationOptionOrderer()); - // delete the bottom half - for (size_t ind = ruleLimit; ind < m_collection.size(); ++ind) { - // make the best score of bottom half the score threshold - float score = m_collection[ind]->GetEstimateOfBestScore(); - m_scoreThreshold = (score > m_scoreThreshold) ? score : m_scoreThreshold; - delete m_collection[ind]; - } - m_collection.resize(ruleLimit); + // prune if bursting + if (m_collection.size() > ruleLimit * 2) { + std::nth_element(m_collection.begin() + , m_collection.begin() + ruleLimit + , m_collection.end() + , ChartTranslationOptionOrderer()); + // delete the bottom half + for (size_t ind = ruleLimit; ind < m_collection.size(); ++ind) { + // make the best score of bottom half the score threshold + float score = m_collection[ind]->GetEstimateOfBestScore(); + m_scoreThreshold = (score > m_scoreThreshold) ? score : m_scoreThreshold; + delete m_collection[ind]; } - + m_collection.resize(ruleLimit); } } @@ -157,15 +151,4 @@ void ChartTranslationOptionList::Sort() std::sort(m_collection.begin(), m_collection.end(), ChartTranslationOptionOrderer()); } -std::ostream& operator<<(std::ostream &out, const ChartTranslationOptionList &coll) -{ - ChartTranslationOptionList::const_iterator iter; - for (iter = coll.begin() ; iter != coll.end() ; ++iter) { - const ChartTranslationOption &rule = **iter; - out << rule << endl; - } - return out; } - -} - diff --git a/moses/src/ChartTrellisNode.cpp b/moses/src/ChartTrellisNode.cpp index 58ff9fa01..6473a731a 100644 --- a/moses/src/ChartTrellisNode.cpp +++ b/moses/src/ChartTrellisNode.cpp @@ -21,6 +21,7 @@ #include "ChartTrellisNode.h" #include "ChartHypothesis.h" +#include "CoveredChartSpan.h" #include "ScoreComponentCollection.h" #include "StaticData.h" @@ -103,8 +104,8 @@ Phrase ChartTrellisNode::GetOutputPhrase() const const ChartTranslationOption &transOpt = m_hypo->GetTranslationOption(); - VERBOSE(3, "Trans Opt:" << transOpt << std::endl); - + VERBOSE(3, "Trans Opt:" << transOpt.GetLastCoveredChartSpan() << ": " << m_hypo->GetCurrTargetPhrase().GetTargetLHS() << "->" << m_hypo->GetCurrTargetPhrase() << std::endl); + const Phrase &currTargetPhrase = m_hypo->GetCurrTargetPhrase(); const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = m_hypo->GetCurrTargetPhrase().GetAlignmentInfo().GetNonTermIndexMap(); diff --git a/moses/src/Makefile.am b/moses/src/Makefile.am index b5296614c..35dc1cc9a 100644 --- a/moses/src/Makefile.am +++ b/moses/src/Makefile.am @@ -92,8 +92,9 @@ libmoses_la_HEADERS = \ PrefixTreeMap.h \ ReorderingConstraint.h \ ReorderingStack.h \ - RuleCube.h \ - RuleCubeQueue.h \ + RuleCube.h \ + RuleCubeItem.h \ + RuleCubeQueue.h \ ScoreComponentCollection.h \ ScoreIndexManager.h \ ScoreProducer.h \ @@ -245,8 +246,9 @@ libmoses_la_SOURCES = \ PrefixTreeMap.cpp \ ReorderingConstraint.cpp \ ReorderingStack.cpp \ - RuleCube.cpp \ - RuleCubeQueue.cpp \ + RuleCube.cpp \ + RuleCubeItem.cpp \ + RuleCubeQueue.cpp \ ScoreComponentCollection.cpp \ ScoreIndexManager.cpp \ ScoreProducer.cpp \ diff --git a/moses/src/Parameter.cpp b/moses/src/Parameter.cpp index 91422e6c8..541628325 100644 --- a/moses/src/Parameter.cpp +++ b/moses/src/Parameter.cpp @@ -124,6 +124,7 @@ Parameter::Parameter() #endif AddParam("cube-pruning-pop-limit", "cbp", "How many hypotheses should be popped for each stack. (default = 1000)"); AddParam("cube-pruning-diversity", "cbd", "How many hypotheses should be created for each coverage. (default = 0)"); + AddParam("cube-pruning-lazy-scoring", "cbls", "Don't fully score a hypothesis until it is popped"); AddParam("search-algorithm", "Which search algorithm to use. 0=normal stack, 1=cube pruning, 2=cube growing. (default = 0)"); AddParam("constraint", "Location of the file with target sentences to produce constraining the search"); AddParam("use-alignment-info", "Use word-to-word alignment: actually it is only used to output the word-to-word alignment. Word-to-word alignments are taken from the phrase table if any. Default is false."); diff --git a/moses/src/PhraseDictionaryNodeSCFG.cpp b/moses/src/PhraseDictionaryNodeSCFG.cpp index 63aeeee21..7f80f8fff 100644 --- a/moses/src/PhraseDictionaryNodeSCFG.cpp +++ b/moses/src/PhraseDictionaryNodeSCFG.cpp @@ -46,6 +46,22 @@ void PhraseDictionaryNodeSCFG::Prune(size_t tableLimit) m_targetPhraseCollection->Prune(true, tableLimit); } +void PhraseDictionaryNodeSCFG::Sort(size_t tableLimit) +{ + // recusively sort + for (TerminalMap::iterator p = m_sourceTermMap.begin(); p != m_sourceTermMap.end(); ++p) { + p->second.Sort(tableLimit); + } + for (NonTerminalMap::iterator p = m_nonTermMap.begin(); p != m_nonTermMap.end(); ++p) { + p->second.Sort(tableLimit); + } + + // prune TargetPhraseCollection in this node + if (m_targetPhraseCollection != NULL) { + m_targetPhraseCollection->Sort(true, tableLimit); + } +} + PhraseDictionaryNodeSCFG *PhraseDictionaryNodeSCFG::GetOrCreateChild(const Word &sourceTerm) { assert(!sourceTerm.IsNonTerminal()); diff --git a/moses/src/PhraseDictionaryNodeSCFG.h b/moses/src/PhraseDictionaryNodeSCFG.h index 2fc8e4fd0..613d872ab 100644 --- a/moses/src/PhraseDictionaryNodeSCFG.h +++ b/moses/src/PhraseDictionaryNodeSCFG.h @@ -170,6 +170,7 @@ public: } void Prune(size_t tableLimit); + void Sort(size_t tableLimit); PhraseDictionaryNodeSCFG *GetOrCreateChild(const Word &sourceTerm); PhraseDictionaryNodeSCFG *GetOrCreateChild(const Word &sourceNonTerm, const Word &targetNonTerm); const PhraseDictionaryNodeSCFG *GetChild(const Word &sourceTerm) const; diff --git a/moses/src/PhraseDictionarySCFG.cpp b/moses/src/PhraseDictionarySCFG.cpp index 3d7cfac1d..0f6c53fec 100644 --- a/moses/src/PhraseDictionarySCFG.cpp +++ b/moses/src/PhraseDictionarySCFG.cpp @@ -176,7 +176,7 @@ bool PhraseDictionarySCFG::Load(const std::vector &input // prune each target phrase collection if (m_tableLimit) { - m_collection.Prune(m_tableLimit); + m_collection.Sort(m_tableLimit); } return true; diff --git a/moses/src/RuleCube.cpp b/moses/src/RuleCube.cpp index 4cf954a7e..d7035a240 100644 --- a/moses/src/RuleCube.cpp +++ b/moses/src/RuleCube.cpp @@ -19,138 +19,87 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA ***********************************************************************/ -#include "RuleCube.h" #include "ChartCell.h" -#include "ChartTranslationOptionCollection.h" #include "ChartCellCollection.h" -#include "RuleCubeQueue.h" -#include "WordsRange.h" #include "ChartTranslationOption.h" -#include "Util.h" +#include "ChartTranslationOptionCollection.h" #include "CoveredChartSpan.h" +#include "RuleCube.h" +#include "RuleCubeQueue.h" +#include "StaticData.h" +#include "Util.h" +#include "WordsRange.h" #ifdef HAVE_BOOST #include #endif -using namespace std; -using namespace Moses; - namespace Moses { -// create a cube for a rule -RuleCube::RuleCube(const ChartTranslationOption &transOpt - , const ChartCellCollection &allChartCells) - :m_transOpt(transOpt) +// initialise the RuleCube by creating the top-left corner item +RuleCube::RuleCube(const ChartTranslationOption &transOpt, + const ChartCellCollection &allChartCells, + ChartManager &manager) + : m_transOpt(transOpt) { - const CoveredChartSpan *coveredChartSpan = &transOpt.GetLastCoveredChartSpan(); - CreateRuleCubeDimension(coveredChartSpan, allChartCells); - CalcScore(); -} - -// for each non-terminal, create a ordered list of matching hypothesis from the chart -void RuleCube::CreateRuleCubeDimension(const CoveredChartSpan *coveredChartSpan, const ChartCellCollection &allChartCells) -{ - // recurse through the linked list of source side non-terminals and terminals - const CoveredChartSpan *prevCoveredChartSpan = coveredChartSpan->GetPrevCoveredChartSpan(); - if (prevCoveredChartSpan) - CreateRuleCubeDimension(prevCoveredChartSpan, allChartCells); - - // only deal with non-terminals - if (coveredChartSpan->IsNonTerminal()) - { - // get the essential information about the non-terminal - const WordsRange &childRange = coveredChartSpan->GetWordsRange(); // span covered by child - const ChartCell &childCell = allChartCells.Get(childRange); // list of all hypos for that span - const Word &nonTerm = coveredChartSpan->GetSourceWord(); // target (sic!) non-terminal label - - // there have to be hypothesis with the desired non-terminal - // (otherwise the rule would not be considered) - assert(!childCell.GetSortedHypotheses(nonTerm).empty()); - - // create a list of hypotheses that match the non-terminal - RuleCubeDimension ruleCubeDimension(0, childCell.GetSortedHypotheses(nonTerm)); - // add them to the vector for such lists - m_cube.push_back(ruleCubeDimension); + RuleCubeItem *item = new RuleCubeItem(transOpt, allChartCells); + m_covered.insert(item); + if (StaticData::Instance().GetCubePruningLazyScoring()) { + item->EstimateScore(); + } else { + item->CreateHypothesis(transOpt, manager); } -} - -// create the RuleCube from an existing one, differing only in one child hypothesis -RuleCube::RuleCube(const RuleCube ©, size_t ruleCubeDimensionIncr) - :m_transOpt(copy.m_transOpt) - ,m_cube(copy.m_cube) -{ - RuleCubeDimension &ruleCubeDimension = m_cube[ruleCubeDimensionIncr]; - ruleCubeDimension.IncrementPos(); - CalcScore(); + m_queue.push(item); } RuleCube::~RuleCube() { - //RemoveAllInColl(m_cube); + RemoveAllInColl(m_covered); +} + +RuleCubeItem *RuleCube::Pop(ChartManager &manager) +{ + RuleCubeItem *item = m_queue.top(); + m_queue.pop(); + CreateNeighbors(*item, manager); + return item; } // create new RuleCube for neighboring principle rules -// (duplicate detection is handled in RuleCubeQueue) -void RuleCube::CreateNeighbors(RuleCubeQueue &queue) const +void RuleCube::CreateNeighbors(const RuleCubeItem &item, ChartManager &manager) { - // loop over all child hypotheses - for (size_t ind = 0; ind < m_cube.size(); ind++) { - const RuleCubeDimension &ruleCubeDimension = m_cube[ind]; + // create neighbor along translation dimension + const TranslationDimension &translationDimension = + item.GetTranslationDimension(); + if (translationDimension.HasMoreTranslations()) { + CreateNeighbor(item, -1, manager); + } - if (ruleCubeDimension.HasMoreHypo()) { - RuleCube *newEntry = new RuleCube(*this, ind); - queue.Add(newEntry); + // create neighbors along all hypothesis dimensions + for (size_t i = 0; i < item.GetHypothesisDimensions().size(); ++i) { + const HypothesisDimension &dimension = item.GetHypothesisDimensions()[i]; + if (dimension.HasMoreHypo()) { + CreateNeighbor(item, i, manager); } } } -// compute an estimated cost of the principle rule -// (consisting of rule translation scores plus child hypotheses scores) -void RuleCube::CalcScore() +void RuleCube::CreateNeighbor(const RuleCubeItem &item, int dimensionIndex, + ChartManager &manager) { - m_combinedScore = m_transOpt.GetTargetPhrase().GetFutureScore(); - for (size_t ind = 0; ind < m_cube.size(); ind++) { - const RuleCubeDimension &ruleCubeDimension = m_cube[ind]; - - const ChartHypothesis *hypo = ruleCubeDimension.GetHypothesis(); - m_combinedScore += hypo->GetTotalScore(); + RuleCubeItem *newItem = new RuleCubeItem(item, dimensionIndex); + std::pair result = m_covered.insert(newItem); + if (!result.second) { + delete newItem; // already seen it + } else { + if (StaticData::Instance().GetCubePruningLazyScoring()) { + newItem->EstimateScore(); + } else { + newItem->CreateHypothesis(m_transOpt, manager); + } + m_queue.push(newItem); } } -bool RuleCube::operator<(const RuleCube &compare) const -{ - if (&m_transOpt != &compare.m_transOpt) - return &m_transOpt < &compare.m_transOpt; - - bool ret = m_cube < compare.m_cube; - return ret; } - -#ifdef HAVE_BOOST -std::size_t hash_value(const RuleCubeDimension & ruleCubeDimension) -{ - boost::hash hasher; - return hasher(ruleCubeDimension.GetHypothesis()); -} - -#endif -std::ostream& operator<<(std::ostream &out, const RuleCubeDimension &ruleCubeDimension) -{ - out << *ruleCubeDimension.GetHypothesis(); - return out; -} - -std::ostream& operator<<(std::ostream &out, const RuleCube &ruleCube) -{ - out << ruleCube.GetTranslationOption() << endl; - std::vector::const_iterator iter; - for (iter = ruleCube.GetCube().begin(); iter != ruleCube.GetCube().end(); ++iter) { - out << *iter << endl; - } - return out; -} - -} - diff --git a/moses/src/RuleCube.h b/moses/src/RuleCube.h index 4897a44fb..2fa6cdedf 100644 --- a/moses/src/RuleCube.h +++ b/moses/src/RuleCube.h @@ -21,105 +21,119 @@ #pragma once -#include -#include +#if HAVE_CONFIG_H +#include "config.h" +#endif + +#include "RuleCubeItem.h" + +#ifdef HAVE_BOOST +#include +#include +#include +#endif + +#include #include #include -#include -#include "WordsRange.h" -#include "Word.h" -#include "ChartHypothesis.h" +#include namespace Moses { -class CoveredChartSpan; -class ChartTranslationOption; -extern bool g_debug; -class TranslationOptionCollection; -class TranslationOptionList; -class ChartCell; + class ChartCellCollection; -class RuleCube; -class RuleCubeQueue; +class ChartManager; +class ChartTranslationOption; -typedef std::vector HypoList; - -// wrapper around list of hypothese for a particular non-term of a trans opt -class RuleCubeDimension +// Define an ordering between RuleCubeItems based on their scores. This +// is used to order items in the cube's priority queue. +class RuleCubeItemScoreOrderer { - friend std::ostream& operator<<(std::ostream&, const RuleCubeDimension&); - -protected: - size_t m_pos; - const HypoList *m_orderedHypos; - -public: - RuleCubeDimension(size_t pos, const HypoList &orderedHypos) - :m_pos(pos) - ,m_orderedHypos(&orderedHypos) - {} - - size_t IncrementPos() { - return m_pos++; - } - - bool HasMoreHypo() const { - return m_pos + 1 < m_orderedHypos->size(); - } - - const ChartHypothesis *GetHypothesis() const { - return (*m_orderedHypos)[m_pos]; - } - - //! transitive comparison used for adding objects into FactorCollection - bool operator<(const RuleCubeDimension &compare) const { - return GetHypothesis() < compare.GetHypothesis(); - } - - bool operator==(const RuleCubeDimension & compare) const { - return GetHypothesis() == compare.GetHypothesis(); + public: + bool operator()(const RuleCubeItem *p, const RuleCubeItem *q) const { + return p->GetScore() < q->GetScore(); } }; -// Stores one dimension in the cube -// (all the hypotheses that match one non terminal) +// Define an ordering between RuleCubeItems based on their positions in the +// cube. This is used to record which positions in the cube have been covered +// during search. +class RuleCubeItemPositionOrderer +{ + public: + bool operator()(const RuleCubeItem *p, const RuleCubeItem *q) const { + return *p < *q; + } +}; + +#ifdef HAVE_BOOST +class RuleCubeItemHasher +{ + public: + size_t operator()(const RuleCubeItem *p) const { + size_t seed = 0; + boost::hash_combine(seed, p->GetHypothesisDimensions()); + boost::hash_combine(seed, p->GetTranslationDimension().GetTargetPhrase()); + return seed; + } +}; + +class RuleCubeItemEqualityPred +{ + public: + bool operator()(const RuleCubeItem *p, const RuleCubeItem *q) const { + return p->GetHypothesisDimensions() == q->GetHypothesisDimensions() && + p->GetTranslationDimension() == q->GetTranslationDimension(); + } +}; +#endif + class RuleCube { - friend std::ostream& operator<<(std::ostream&, const RuleCube&); -protected: - const ChartTranslationOption &m_transOpt; - std::vector m_cube; + public: + RuleCube(const ChartTranslationOption &, const ChartCellCollection &, + ChartManager &); - float m_combinedScore; - - RuleCube(const RuleCube ©, size_t ruleCubeDimensionIncr); - void CreateRuleCubeDimension(const CoveredChartSpan *coveredChartSpan, const ChartCellCollection &allChartCells); - - void CalcScore(); - -public: - RuleCube(const ChartTranslationOption &transOpt - , const ChartCellCollection &allChartCells); ~RuleCube(); + float GetTopScore() const { + assert(!m_queue.empty()); + RuleCubeItem *item = m_queue.top(); + return item->GetScore(); + } + + RuleCubeItem *Pop(ChartManager &); + + bool IsEmpty() const { return m_queue.empty(); } + const ChartTranslationOption &GetTranslationOption() const { return m_transOpt; } - const std::vector &GetCube() const { - return m_cube; - } - float GetCombinedScore() const { - return m_combinedScore; - } - void CreateNeighbors(RuleCubeQueue &) const; - - bool operator<(const RuleCube &compare) const; - -}; - -#ifdef HAVE_BOOST -std::size_t hash_value(const RuleCubeDimension &); + private: +#if defined(BOOST_VERSION) && (BOOST_VERSION >= 104200) + typedef boost::unordered_set ItemSet; +#else + typedef std::set ItemSet; #endif + typedef std::priority_queue, + RuleCubeItemScoreOrderer + > Queue; + + RuleCube(const RuleCube &); // Not implemented + RuleCube &operator=(const RuleCube &); // Not implemented + + void CreateNeighbors(const RuleCubeItem &, ChartManager &); + void CreateNeighbor(const RuleCubeItem &, int, ChartManager &); + + const ChartTranslationOption &m_transOpt; + ItemSet m_covered; + Queue m_queue; +}; + } diff --git a/moses/src/RuleCubeItem.cpp b/moses/src/RuleCubeItem.cpp new file mode 100644 index 000000000..052159d6e --- /dev/null +++ b/moses/src/RuleCubeItem.cpp @@ -0,0 +1,141 @@ +/*********************************************************************** + Moses - statistical machine translation system + Copyright (C) 2006-2011 University of Edinburgh + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + ***********************************************************************/ + +#include "ChartCell.h" +#include "ChartCellCollection.h" +#include "ChartTranslationOption.h" +#include "ChartTranslationOptionCollection.h" +#include "CoveredChartSpan.h" +#include "RuleCubeItem.h" +#include "RuleCubeQueue.h" +#include "WordsRange.h" +#include "Util.h" + +#ifdef HAVE_BOOST +#include +#endif + +namespace Moses +{ + +#ifdef HAVE_BOOST +std::size_t hash_value(const HypothesisDimension &dimension) +{ + boost::hash hasher; + return hasher(dimension.GetHypothesis()); +} +#endif + +RuleCubeItem::RuleCubeItem(const ChartTranslationOption &transOpt, + const ChartCellCollection &allChartCells) + : m_translationDimension(0, + transOpt.GetTargetPhraseCollection().GetCollection()) + , m_hypothesis(0) +{ + const CoveredChartSpan *lastCCS = &transOpt.GetLastCoveredChartSpan(); + CreateHypothesisDimensions(lastCCS, allChartCells); +} + +// create the RuleCube from an existing one, differing only in one dimension +RuleCubeItem::RuleCubeItem(const RuleCubeItem ©, int hypoDimensionIncr) + : m_translationDimension(copy.m_translationDimension) + , m_hypothesisDimensions(copy.m_hypothesisDimensions) + , m_hypothesis(0) +{ + if (hypoDimensionIncr == -1) { + m_translationDimension.IncrementPos(); + } else { + HypothesisDimension &dimension = m_hypothesisDimensions[hypoDimensionIncr]; + dimension.IncrementPos(); + } +} + +RuleCubeItem::~RuleCubeItem() +{ + delete m_hypothesis; +} + +void RuleCubeItem::EstimateScore() +{ + m_score = m_translationDimension.GetTargetPhrase()->GetFutureScore(); + std::vector::const_iterator p; + for (p = m_hypothesisDimensions.begin(); + p != m_hypothesisDimensions.end(); ++p) { + m_score += p->GetHypothesis()->GetTotalScore(); + } +} + +void RuleCubeItem::CreateHypothesis(const ChartTranslationOption &transOpt, + ChartManager &manager) +{ + m_hypothesis = new ChartHypothesis(transOpt, *this, manager); + m_hypothesis->CalcScore(); + m_score = m_hypothesis->GetTotalScore(); +} + +ChartHypothesis *RuleCubeItem::ReleaseHypothesis() +{ + assert(m_hypothesis); + ChartHypothesis *hypo = m_hypothesis; + m_hypothesis = 0; + return hypo; +} + +// for each non-terminal, create a ordered list of matching hypothesis from the +// chart +void RuleCubeItem::CreateHypothesisDimensions( + const CoveredChartSpan *coveredChartSpan, + const ChartCellCollection &allChartCells) +{ + // recurse through the linked list of source side non-terminals and terminals + const CoveredChartSpan *prev = coveredChartSpan->GetPrevCoveredChartSpan(); + if (prev) { + CreateHypothesisDimensions(prev, allChartCells); + } + + // only deal with non-terminals + if (coveredChartSpan->IsNonTerminal()) { + // get the essential information about the non-terminal: + // span covered by child + const WordsRange &childRange = coveredChartSpan->GetWordsRange(); + // list of all hypos for that span + const ChartCell &childCell = allChartCells.Get(childRange); + // target (sic!) non-terminal label + const Word &nonTerm = coveredChartSpan->GetSourceWord(); + + // there have to be hypothesis with the desired non-terminal + // (otherwise the rule would not be considered) + assert(!childCell.GetSortedHypotheses(nonTerm).empty()); + + // create a list of hypotheses that match the non-terminal + HypothesisDimension dimension(0, childCell.GetSortedHypotheses(nonTerm)); + // add them to the vector for such lists + m_hypothesisDimensions.push_back(dimension); + } +} + +bool RuleCubeItem::operator<(const RuleCubeItem &compare) const +{ + if (m_translationDimension == compare.m_translationDimension) { + return m_hypothesisDimensions < compare.m_hypothesisDimensions; + } + return m_translationDimension < compare.m_translationDimension; +} + +} diff --git a/moses/src/RuleCubeItem.h b/moses/src/RuleCubeItem.h new file mode 100644 index 000000000..5ff9a28fa --- /dev/null +++ b/moses/src/RuleCubeItem.h @@ -0,0 +1,148 @@ +/*********************************************************************** + Moses - statistical machine translation system + Copyright (C) 2006-2011 University of Edinburgh + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + ***********************************************************************/ + +#pragma once + +#if HAVE_CONFIG_H +#include "config.h" +#endif + +#include + +namespace Moses +{ + +class ChartCellCollection; +class ChartHypothesis; +class ChartManager; +class ChartTranslationOption; +class CoveredChartSpan; +class TargetPhrase; + +typedef std::vector HypoList; + +// wrapper around list of target phrase translation options +class TranslationDimension +{ + public: + TranslationDimension(size_t pos, + const std::vector &orderedTargetPhrases) + : m_pos(pos) + , m_orderedTargetPhrases(&orderedTargetPhrases) + {} + + size_t IncrementPos() { return m_pos++; } + + bool HasMoreTranslations() const { + return m_pos+1 < m_orderedTargetPhrases->size(); + } + + const TargetPhrase *GetTargetPhrase() const { + return (*m_orderedTargetPhrases)[m_pos]; + } + + bool operator<(const TranslationDimension &compare) const { + return GetTargetPhrase() < compare.GetTargetPhrase(); + } + + bool operator==(const TranslationDimension &compare) const { + return GetTargetPhrase() == compare.GetTargetPhrase(); + } + + private: + size_t m_pos; + const std::vector *m_orderedTargetPhrases; +}; + + +// wrapper around list of hypotheses for a particular non-term of a trans opt +class HypothesisDimension +{ +public: + HypothesisDimension(size_t pos, const HypoList &orderedHypos) + : m_pos(pos) + , m_orderedHypos(&orderedHypos) + {} + + size_t IncrementPos() { return m_pos++; } + + bool HasMoreHypo() const { + return m_pos+1 < m_orderedHypos->size(); + } + + const ChartHypothesis *GetHypothesis() const { + return (*m_orderedHypos)[m_pos]; + } + + bool operator<(const HypothesisDimension &compare) const { + return GetHypothesis() < compare.GetHypothesis(); + } + + bool operator==(const HypothesisDimension &compare) const { + return GetHypothesis() == compare.GetHypothesis(); + } + +private: + size_t m_pos; + const HypoList *m_orderedHypos; +}; + +#ifdef HAVE_BOOST +std::size_t hash_value(const HypothesisDimension &); +#endif + +class RuleCubeItem +{ + public: + RuleCubeItem(const ChartTranslationOption &, const ChartCellCollection &); + RuleCubeItem(const RuleCubeItem &, int); + ~RuleCubeItem(); + + const TranslationDimension &GetTranslationDimension() const { + return m_translationDimension; + } + + const std::vector &GetHypothesisDimensions() const { + return m_hypothesisDimensions; + } + + float GetScore() const { return m_score; } + + void EstimateScore(); + + void CreateHypothesis(const ChartTranslationOption &, ChartManager &); + + ChartHypothesis *ReleaseHypothesis(); + + bool operator<(const RuleCubeItem &) const; + + private: + RuleCubeItem(const RuleCubeItem &); // Not implemented + RuleCubeItem &operator=(const RuleCubeItem &); // Not implemented + + void CreateHypothesisDimensions(const CoveredChartSpan *, + const ChartCellCollection &); + + TranslationDimension m_translationDimension; + std::vector m_hypothesisDimensions; + ChartHypothesis *m_hypothesis; + float m_score; +}; + +} diff --git a/moses/src/RuleCubeQueue.cpp b/moses/src/RuleCubeQueue.cpp index 4925a0d1c..89020a4e5 100644 --- a/moses/src/RuleCubeQueue.cpp +++ b/moses/src/RuleCubeQueue.cpp @@ -21,42 +21,48 @@ #include "RuleCubeQueue.h" -#include "Util.h" - -using namespace std; +#include "RuleCubeItem.h" +#include "StaticData.h" namespace Moses { + RuleCubeQueue::~RuleCubeQueue() { - RemoveAllInColl(m_uniqueEntry); + while (!m_queue.empty()) { + RuleCube *cube = m_queue.top(); + m_queue.pop(); + delete cube; + } } -bool RuleCubeQueue::Add(RuleCube *ruleCube) +void RuleCubeQueue::Add(RuleCube *ruleCube) { - pair inserted = m_uniqueEntry.insert(ruleCube); + m_queue.push(ruleCube); +} - if (inserted.second) { - // inserted - m_sortedByScore.push(ruleCube); +ChartHypothesis *RuleCubeQueue::Pop() +{ + // pop the most promising rule cube + RuleCube *cube = m_queue.top(); + m_queue.pop(); + + // pop the most promising item from the cube and get the corresponding + // hypothesis + RuleCubeItem *item = cube->Pop(m_manager); + if (StaticData::Instance().GetCubePruningLazyScoring()) { + item->CreateHypothesis(cube->GetTranslationOption(), m_manager); + } + ChartHypothesis *hypo = item->ReleaseHypothesis(); + + // if the cube contains more items then push it back onto the queue + if (!cube->IsEmpty()) { + m_queue.push(cube); } else { - // already there - //cerr << "already there\n"; - delete ruleCube; + delete cube; } - //assert(m_uniqueEntry.size() == m_sortedByScore.size()); - - return inserted.second; -} - -RuleCube *RuleCubeQueue::Pop() -{ - RuleCube *ruleCube = m_sortedByScore.top(); - m_sortedByScore.pop(); - - return ruleCube; + return hypo; } } - diff --git a/moses/src/RuleCubeQueue.h b/moses/src/RuleCubeQueue.h index 8fbd27958..5ef3b3af2 100644 --- a/moses/src/RuleCubeQueue.h +++ b/moses/src/RuleCubeQueue.h @@ -18,85 +18,49 @@ License along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA ***********************************************************************/ -#include -#include -#include + +#pragma once + +#if HAVE_CONFIG_H +#include "config.h" +#endif + #include "RuleCube.h" -#ifdef HAVE_BOOST -#include -#include -#include -#endif +#include +#include namespace Moses { -#ifdef HAVE_BOOST -class RuleCubeUniqueHasher +class ChartManager; + +// Define an ordering between RuleCube based on their best item scores. This +// is used to order items in the priority queue. +class RuleCubeOrderer { -public: - size_t operator()(const RuleCube * p) const { - size_t seed = 0; - boost::hash_combine(seed, &(p->GetTranslationOption())); - boost::hash_combine(seed, p->GetCube()); - return seed; + public: + bool operator()(const RuleCube *p, const RuleCube *q) const { + return p->GetTopScore() < q->GetTopScore(); } }; -class RuleCubeUniqueEqualityPred -{ -public: - bool operator()(const RuleCube * p, const RuleCube * q) const { - return ((&(p->GetTranslationOption()) == &(q->GetTranslationOption())) - && (p->GetCube() == q->GetCube())); - } -}; -#endif - -class RuleCubeUniqueOrderer -{ -public: - bool operator()(const RuleCube* entryA, const RuleCube* entryB) const { - return (*entryA) < (*entryB); - } -}; - -class RuleCubeScoreOrderer -{ -public: - bool operator()(const RuleCube* entryA, const RuleCube* entryB) const { - return (entryA->GetCombinedScore() < entryB->GetCombinedScore()); - } -}; - - class RuleCubeQueue { -protected: -#if defined(BOOST_VERSION) && (BOOST_VERSION >= 104200) - typedef boost::unordered_set UniqueCubeEntry; -#else - typedef std::set UniqueCubeEntry; -#endif - UniqueCubeEntry m_uniqueEntry; - - typedef std::priority_queue, RuleCubeScoreOrderer> SortedByScore; - SortedByScore m_sortedByScore; - - -public: + public: + RuleCubeQueue(ChartManager &manager) : m_manager(manager) {} ~RuleCubeQueue(); - bool IsEmpty() const { - return m_sortedByScore.empty(); - } - - RuleCube *Pop(); - bool Add(RuleCube *ruleCube); -}; - + void Add(RuleCube *); + ChartHypothesis *Pop(); + bool IsEmpty() const { return m_queue.empty(); } + + private: + typedef std::priority_queue, + RuleCubeOrderer > Queue; + + Queue m_queue; + ChartManager &m_manager; }; +} diff --git a/moses/src/StaticData.cpp b/moses/src/StaticData.cpp index a479e40b4..d386ee4d7 100644 --- a/moses/src/StaticData.cpp +++ b/moses/src/StaticData.cpp @@ -333,6 +333,8 @@ bool StaticData::LoadData(Parameter *parameter) m_cubePruningDiversity = (m_parameter->GetParam("cube-pruning-diversity").size() > 0) ? Scan(m_parameter->GetParam("cube-pruning-diversity")[0]) : DEFAULT_CUBE_PRUNING_DIVERSITY; + SetBooleanParameter(&m_cubePruningLazyScoring, "cube-pruning-lazy-scoring", false); + // unknown word processing SetBooleanParameter( &m_dropUnknown, "drop-unknown", false ); diff --git a/moses/src/StaticData.h b/moses/src/StaticData.h index 9a3149f67..37fdf3dc1 100644 --- a/moses/src/StaticData.h +++ b/moses/src/StaticData.h @@ -195,6 +195,7 @@ protected: size_t m_cubePruningPopLimit; size_t m_cubePruningDiversity; + bool m_cubePruningLazyScoring; size_t m_ruleLimit; @@ -313,6 +314,9 @@ public: size_t GetCubePruningDiversity() const { return m_cubePruningDiversity; } + bool GetCubePruningLazyScoring() const { + return m_cubePruningLazyScoring; + } size_t IsPathRecoveryEnabled() const { return m_recoverPath; } diff --git a/moses/src/TargetPhraseCollection.cpp b/moses/src/TargetPhraseCollection.cpp index 398acf345..38570f8d5 100644 --- a/moses/src/TargetPhraseCollection.cpp +++ b/moses/src/TargetPhraseCollection.cpp @@ -55,6 +55,25 @@ void TargetPhraseCollection::Prune(bool adhereTableLimit, size_t tableLimit) } } +void TargetPhraseCollection::Sort(bool adhereTableLimit, size_t tableLimit) +{ + std::vector::iterator iterMiddle; + iterMiddle = (tableLimit == 0 || m_collection.size() < tableLimit) + ? m_collection.end() + : m_collection.begin()+tableLimit; + + std::partial_sort(m_collection.begin(), iterMiddle, m_collection.end(), + CompareTargetPhrase()); + + if (adhereTableLimit && m_collection.size() > tableLimit) { + for (size_t i = tableLimit; i < m_collection.size(); ++i) { + TargetPhrase *targetPhrase = m_collection[i]; + delete targetPhrase; + } + m_collection.erase(m_collection.begin()+tableLimit, m_collection.end()); + } +} + } diff --git a/moses/src/TargetPhraseCollection.h b/moses/src/TargetPhraseCollection.h index 849401848..f4124f458 100644 --- a/moses/src/TargetPhraseCollection.h +++ b/moses/src/TargetPhraseCollection.h @@ -57,6 +57,8 @@ public: RemoveAllInColl(m_collection); } + const std::vector &GetCollection() const { return m_collection; } + //! divide collection into 2 buckets using std::nth_element, the top & bottom according to table limit void NthElement(size_t tableLimit); @@ -74,6 +76,7 @@ public: } void Prune(bool adhereTableLimit, size_t tableLimit); + void Sort(bool adhereTableLimit, size_t tableLimit); };