From ea9c721733652f38434a5b27231775eb3ab7268a Mon Sep 17 00:00:00 2001 From: Phil Williams Date: Mon, 6 Feb 2012 23:54:01 +0000 Subject: [PATCH] moses_chart: reduce memory usage by creating one ChartTranslationOptionList per sentence instead of one per cell, and reduce object churn by recycling ChartTranslationOption objects. --- .../ChartRuleLookupManagerCYKPlus.cpp | 4 +- .../ChartRuleLookupManagerCYKPlus.h | 3 +- .../ChartRuleLookupManagerMemory.cpp | 5 +- .../ChartRuleLookupManagerOnDisk.cpp | 7 +- moses/src/ChartCell.cpp | 6 +- moses/src/ChartManager.cpp | 12 +- moses/src/ChartTranslationOption.cpp | 14 +- moses/src/ChartTranslationOption.h | 29 ++-- .../src/ChartTranslationOptionCollection.cpp | 78 +++-------- moses/src/ChartTranslationOptionCollection.h | 20 +-- moses/src/ChartTranslationOptionList.cpp | 127 ++++++++---------- moses/src/ChartTranslationOptionList.h | 106 ++++----------- moses/src/Scope3Parser/Parser.cpp | 3 +- moses/src/Scope3Parser/Parser.h | 8 +- 14 files changed, 147 insertions(+), 275 deletions(-) diff --git a/moses/src/CYKPlusParser/ChartRuleLookupManagerCYKPlus.cpp b/moses/src/CYKPlusParser/ChartRuleLookupManagerCYKPlus.cpp index 074e7b1e9..574572be8 100644 --- a/moses/src/CYKPlusParser/ChartRuleLookupManagerCYKPlus.cpp +++ b/moses/src/CYKPlusParser/ChartRuleLookupManagerCYKPlus.cpp @@ -34,7 +34,7 @@ namespace Moses void ChartRuleLookupManagerCYKPlus::AddCompletedRule( const DottedRule &dottedRule, const TargetPhraseCollection &tpc, - size_t ruleLimit, + const WordsRange &range, ChartTranslationOptionList &outColl) { // Determine the rule's rank. @@ -61,7 +61,7 @@ void ChartRuleLookupManagerCYKPlus::AddCompletedRule( } // Add the (TargetPhraseCollection, StackVec) pair to the collection. - outColl.Add(tpc, m_stackVec, ruleLimit); + outColl.Add(tpc, m_stackVec, range); } } // namespace Moses diff --git a/moses/src/CYKPlusParser/ChartRuleLookupManagerCYKPlus.h b/moses/src/CYKPlusParser/ChartRuleLookupManagerCYKPlus.h index 240d101b7..dbee00b6f 100644 --- a/moses/src/CYKPlusParser/ChartRuleLookupManagerCYKPlus.h +++ b/moses/src/CYKPlusParser/ChartRuleLookupManagerCYKPlus.h @@ -29,6 +29,7 @@ namespace Moses class DottedRule; class TargetPhraseCollection; +class WordsRange; class ChartRuleLookupManagerCYKPlus : public ChartRuleLookupManager { @@ -41,7 +42,7 @@ class ChartRuleLookupManagerCYKPlus : public ChartRuleLookupManager void AddCompletedRule( const DottedRule &dottedRule, const TargetPhraseCollection &tpc, - size_t ruleLimit, + const WordsRange &range, ChartTranslationOptionList &outColl); StackVec m_stackVec; diff --git a/moses/src/CYKPlusParser/ChartRuleLookupManagerMemory.cpp b/moses/src/CYKPlusParser/ChartRuleLookupManagerMemory.cpp index d3bb3fa81..d180905aa 100644 --- a/moses/src/CYKPlusParser/ChartRuleLookupManagerMemory.cpp +++ b/moses/src/CYKPlusParser/ChartRuleLookupManagerMemory.cpp @@ -151,7 +151,6 @@ void ChartRuleLookupManagerMemory::GetChartRuleCollection( DottedRuleList &rules = dottedRuleCol.Get(relEndPos + 1); // look up target sides for the rules - const size_t ruleLimit = StaticData::Instance().GetRuleLimit(); DottedRuleList::const_iterator iterRule; for (iterRule = rules.begin(); iterRule != rules.end(); ++iterRule) { const DottedRuleInMemory &dottedRule = **iterRule; @@ -162,13 +161,13 @@ void ChartRuleLookupManagerMemory::GetChartRuleCollection( // add the fully expanded rule (with lexical target side) if (tpc != NULL) { - AddCompletedRule(dottedRule, *tpc, ruleLimit, outColl); + AddCompletedRule(dottedRule, *tpc, range, outColl); } } dottedRuleCol.Clear(relEndPos+1); - outColl.CreateChartRules(ruleLimit); + outColl.ShrinkToLimit(); } // Given a partial rule application ending at startPos-1 and given the sets of diff --git a/moses/src/CYKPlusParser/ChartRuleLookupManagerOnDisk.cpp b/moses/src/CYKPlusParser/ChartRuleLookupManagerOnDisk.cpp index 70549e636..08db10472 100644 --- a/moses/src/CYKPlusParser/ChartRuleLookupManagerOnDisk.cpp +++ b/moses/src/CYKPlusParser/ChartRuleLookupManagerOnDisk.cpp @@ -83,9 +83,6 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection( const WordsRange &range, ChartTranslationOptionList &outColl) { - const StaticData &staticData = StaticData::Instance(); - size_t rulesLimit = staticData.GetRuleLimit(); - size_t relEndPos = range.GetEndPos() - range.GetStartPos(); size_t absEndPos = range.GetEndPos(); @@ -260,7 +257,7 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection( CHECK(targetPhraseCollection); if (!targetPhraseCollection->IsEmpty()) { AddCompletedRule(prevDottedRule, *targetPhraseCollection, - rulesLimit, outColl); + range, outColl); } } // if (node) @@ -271,7 +268,7 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection( } } // for (size_t ind = 0; ind < savedNodeColl.size(); ++ind) - outColl.CreateChartRules(rulesLimit); + outColl.ShrinkToLimit(); //cerr << numDerivations << " "; } diff --git a/moses/src/ChartCell.cpp b/moses/src/ChartCell.cpp index 4f91d93c0..89341d8e7 100644 --- a/moses/src/ChartCell.cpp +++ b/moses/src/ChartCell.cpp @@ -88,10 +88,8 @@ void ChartCell::ProcessSentence(const ChartTranslationOptionList &transOptList RuleCubeQueue queue(m_manager); // add all trans opt into queue. using only 1st child node. - ChartTranslationOptionList::const_iterator iterList; - for (iterList = transOptList.begin(); iterList != transOptList.end(); ++iterList) - { - const ChartTranslationOption &transOpt = **iterList; + for (size_t i = 0; i < transOptList.GetSize(); ++i) { + const ChartTranslationOption &transOpt = transOptList.Get(i); RuleCube *ruleCube = new RuleCube(transOpt, allChartCells, m_manager); queue.Add(ruleCube); } diff --git a/moses/src/ChartManager.cpp b/moses/src/ChartManager.cpp index b2ed3a435..a0737ffcb 100644 --- a/moses/src/ChartManager.cpp +++ b/moses/src/ChartManager.cpp @@ -85,25 +85,19 @@ void ChartManager::ProcessSentence() for (size_t startPos = 0; startPos <= size-width; ++startPos) { size_t endPos = startPos + width - 1; WordsRange range(startPos, endPos); - //TRACE_ERR(" " << range << "="); // create trans opt - m_transOptColl.CreateTranslationOptionsForRange(startPos, endPos); - //if (g_debug) - // cerr << m_transOptColl.GetTranslationOptionList(WordsRange(startPos, endPos)); + m_transOptColl.CreateTranslationOptionsForRange(range); // decode ChartCell &cell = m_hypoStackColl.Get(range); - cell.ProcessSentence(m_transOptColl.GetTranslationOptionList(range) + cell.ProcessSentence(m_transOptColl.GetTranslationOptionList() ,m_hypoStackColl); + m_transOptColl.Clear(); cell.PruneToSize(); cell.CleanupArcList(); cell.SortHypotheses(); - - //cerr << cell.GetSize(); - //cerr << cell << endl; - //cell.OutputSizes(cerr); } } diff --git a/moses/src/ChartTranslationOption.cpp b/moses/src/ChartTranslationOption.cpp index 997a3445a..792bfde82 100644 --- a/moses/src/ChartTranslationOption.cpp +++ b/moses/src/ChartTranslationOption.cpp @@ -24,19 +24,21 @@ namespace Moses { -void ChartTranslationOption::CalcEstimateOfBestScore() +float ChartTranslationOption::CalcEstimateOfBestScore( + const TargetPhraseCollection &tpc, + const StackVec &stackVec) { - const TargetPhrase &targetPhrase = **(m_targetPhraseCollection.begin()); - m_estimateOfBestScore = targetPhrase.GetFutureScore(); - - for (StackVec::const_iterator p = m_stackVec.begin(); p != m_stackVec.end(); + const TargetPhrase &targetPhrase = **(tpc.begin()); + float estimateOfBestScore = targetPhrase.GetFutureScore(); + for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end(); ++p) { const HypoList *stack = *p; assert(stack); assert(!stack->empty()); const ChartHypothesis &bestHypo = **(stack->begin()); - m_estimateOfBestScore += bestHypo.GetTotalScore(); + estimateOfBestScore += bestHypo.GetTotalScore(); } + return estimateOfBestScore; } } diff --git a/moses/src/ChartTranslationOption.h b/moses/src/ChartTranslationOption.h index 888d7138d..d17f00427 100644 --- a/moses/src/ChartTranslationOption.h +++ b/moses/src/ChartTranslationOption.h @@ -37,25 +37,26 @@ class ChartTranslationOption public: ChartTranslationOption(const TargetPhraseCollection &targetPhraseColl, const StackVec &stackVec, - const WordsRange &wordsRange) + const WordsRange &wordsRange, + float score) : m_stackVec(stackVec) - , m_targetPhraseCollection(targetPhraseColl) - , m_wordsRange(wordsRange) - , m_estimateOfBestScore(0) - { - CalcEstimateOfBestScore(); - } + , m_targetPhraseCollection(&targetPhraseColl) + , m_wordsRange(&wordsRange) + , m_estimateOfBestScore(score) {} ~ChartTranslationOption() {} + static float CalcEstimateOfBestScore(const TargetPhraseCollection &, + const StackVec &); + const StackVec &GetStackVec() const { return m_stackVec; } const TargetPhraseCollection &GetTargetPhraseCollection() const { - return m_targetPhraseCollection; + return *m_targetPhraseCollection; } const WordsRange &GetSourceWordsRange() const { - return m_wordsRange; + return *m_wordsRange; } // return an estimate of the best score possible with this translation option. @@ -64,14 +65,10 @@ class ChartTranslationOption inline float GetEstimateOfBestScore() const { return m_estimateOfBestScore; } private: - // not implemented - ChartTranslationOption &operator=(const ChartTranslationOption &); - void CalcEstimateOfBestScore(); - - const StackVec m_stackVec; - const TargetPhraseCollection &m_targetPhraseCollection; - const WordsRange &m_wordsRange; + StackVec m_stackVec; + const TargetPhraseCollection *m_targetPhraseCollection; + const WordsRange *m_wordsRange; float m_estimateOfBestScore; }; diff --git a/moses/src/ChartTranslationOptionCollection.cpp b/moses/src/ChartTranslationOptionCollection.cpp index 19b17499b..ea9ce9426 100644 --- a/moses/src/ChartTranslationOptionCollection.cpp +++ b/moses/src/ChartTranslationOptionCollection.cpp @@ -41,16 +41,8 @@ ChartTranslationOptionCollection::ChartTranslationOptionCollection(InputType con ,m_decodeGraphList(system->GetDecodeGraphs()) ,m_hypoStackColl(hypoStackColl) ,m_ruleLookupManagers(ruleLookupManagers) - ,m_collection(source.GetSize()) + ,m_translationOptionList(StaticData::Instance().GetRuleLimit()) { - // create 2-d vector - size_t size = source.GetSize(); - for (size_t startPos = 0 ; startPos < size ; ++startPos) { - m_collection[startPos].reserve(size-startPos); - for (size_t endPos = startPos ; endPos < size ; ++endPos) { - m_collection[startPos].push_back( ChartTranslationOptionList(WordsRange(startPos, endPos)) ); - } - } } ChartTranslationOptionCollection::~ChartTranslationOptionCollection() @@ -60,13 +52,12 @@ ChartTranslationOptionCollection::~ChartTranslationOptionCollection() } void ChartTranslationOptionCollection::CreateTranslationOptionsForRange( - size_t startPos - , size_t endPos) + const WordsRange &wordsRange) { - ChartTranslationOptionList &chartRuleColl = GetTranslationOptionList(startPos, endPos); - const WordsRange &wordsRange = chartRuleColl.GetSourceRange(); - assert(m_decodeGraphList.size() == m_ruleLookupManagers.size()); + + m_translationOptionList.Clear(); + std::vector ::const_iterator iterDecodeGraph; std::vector ::const_iterator iterRuleLookupManagers = m_ruleLookupManagers.begin(); for (iterDecodeGraph = m_decodeGraphList.begin(); iterDecodeGraph != m_decodeGraphList.end(); ++iterDecodeGraph, ++iterRuleLookupManagers) { @@ -74,64 +65,31 @@ void ChartTranslationOptionCollection::CreateTranslationOptionsForRange( assert(decodeGraph.GetSize() == 1); ChartRuleLookupManager &ruleLookupManager = **iterRuleLookupManagers; size_t maxSpan = decodeGraph.GetMaxChartSpan(); - if (maxSpan == 0 || (endPos-startPos+1) <= maxSpan) { - ruleLookupManager.GetChartRuleCollection(wordsRange, chartRuleColl); + if (maxSpan == 0 || wordsRange.GetNumWordsCovered() <= maxSpan) { + ruleLookupManager.GetChartRuleCollection(wordsRange, m_translationOptionList); } } - if (startPos == endPos && startPos != 0 && startPos != m_source.GetSize()-1) { + if (wordsRange.GetNumWordsCovered() == 1 && wordsRange.GetStartPos() != 0 && wordsRange.GetStartPos() != m_source.GetSize()-1) { bool alwaysCreateDirectTranslationOption = StaticData::Instance().IsAlwaysCreateDirectTranslationOption(); - if (chartRuleColl.GetSize() == 0 || alwaysCreateDirectTranslationOption) { + if (m_translationOptionList.GetSize() == 0 || alwaysCreateDirectTranslationOption) { // create unknown words for 1 word coverage where we don't have any trans options - const Word &sourceWord = m_source.GetWord(startPos); - ProcessOneUnknownWord(sourceWord, startPos); + const Word &sourceWord = m_source.GetWord(wordsRange.GetStartPos()); + ProcessOneUnknownWord(sourceWord, wordsRange); } } - chartRuleColl.ApplyThreshold(); -} - -ChartTranslationOptionList &ChartTranslationOptionCollection::GetTranslationOptionList(size_t startPos, size_t endPos) -{ - size_t sizeVec = m_collection[startPos].size(); - CHECK(endPos-startPos < sizeVec); - return m_collection[startPos][endPos - startPos]; -} -const ChartTranslationOptionList &ChartTranslationOptionCollection::GetTranslationOptionList(size_t startPos, size_t endPos) const -{ - size_t sizeVec = m_collection[startPos].size(); - CHECK(endPos-startPos < sizeVec); - return m_collection[startPos][endPos - startPos]; -} - -std::ostream& operator<<(std::ostream &out, const ChartTranslationOptionCollection &coll) -{ - std::vector< std::vector< ChartTranslationOptionList > >::const_iterator iterOuter; - for (iterOuter = coll.m_collection.begin(); iterOuter != coll.m_collection.end(); ++iterOuter) { - const std::vector< ChartTranslationOptionList > &vecInner = *iterOuter; - std::vector< ChartTranslationOptionList >::const_iterator iterInner; - - for (iterInner = vecInner.begin(); iterInner != vecInner.end(); ++iterInner) { - const ChartTranslationOptionList &list = *iterInner; - out << list.GetSourceRange() << " = " << list.GetSize() << std::endl; - } - } - - - return out; + m_translationOptionList.ApplyThreshold(); } //! special handling of ONE unknown words. -void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceWord, size_t sourcePos, size_t /* length */) +void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceWord, const WordsRange &range) { // unknown word, add as trans opt const StaticData &staticData = StaticData::Instance(); const UnknownWordPenaltyProducer *unknownWordPenaltyProducer = m_system->GetUnknownWordPenaltyProducer(); vector wordPenaltyScore(1, -0.434294482); // TODO what is this number? - ChartTranslationOptionList &transOptColl = GetTranslationOptionList(sourcePos, sourcePos); - const WordsRange &range = transOptColl.GetSourceRange(); - const ChartCell &chartCell = m_hypoStackColl.Get(range); const ChartCellLabel &sourceWordLabel = chartCell.GetSourceWordLabel(); @@ -186,10 +144,7 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW targetPhrase->SetTargetLHS(targetLHS); // chart rule - ChartTranslationOption *chartRule = new ChartTranslationOption(*tpc - , m_emptyStackVec - , range); - transOptColl.Add(chartRule); + m_translationOptionList.Add(*tpc, m_emptyStackVec, range); } // for (iterLHS } else { // drop source word. create blank trans opt @@ -215,10 +170,7 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW targetPhrase->SetTargetLHS(targetLHS); // chart rule - ChartTranslationOption *chartRule = new ChartTranslationOption(*tpc - , m_emptyStackVec - , range); - transOptColl.Add(chartRule); + m_translationOptionList.Add(*tpc, m_emptyStackVec, range); } } } diff --git a/moses/src/ChartTranslationOptionCollection.h b/moses/src/ChartTranslationOptionCollection.h index f0e99ea0e..4f5b84062 100644 --- a/moses/src/ChartTranslationOptionCollection.h +++ b/moses/src/ChartTranslationOptionCollection.h @@ -38,7 +38,6 @@ class ChartCellCollection; class ChartTranslationOptionCollection { - friend std::ostream& operator<<(std::ostream&, const ChartTranslationOptionCollection&); protected: const InputType &m_source; const TranslationSystem* m_system; @@ -46,18 +45,13 @@ protected: const ChartCellCollection &m_hypoStackColl; const std::vector &m_ruleLookupManagers; - std::vector< std::vector< ChartTranslationOptionList > > m_collection; /*< contains translation options */ + ChartTranslationOptionList m_translationOptionList; std::vector m_unksrcs; std::list m_cacheTargetPhraseCollection; StackVec m_emptyStackVec; - ChartTranslationOptionList &GetTranslationOptionList(size_t startPos, size_t endPos); - const ChartTranslationOptionList &GetTranslationOptionList(size_t startPos, size_t endPos) const; - - //! special handling of ONE unknown words. - virtual void ProcessOneUnknownWord(const Word &sourceWord - , size_t sourcePos, size_t length = 1); + virtual void ProcessOneUnknownWord(const Word &, const WordsRange &); public: ChartTranslationOptionCollection(InputType const& source @@ -65,14 +59,14 @@ public: , const ChartCellCollection &hypoStackColl , const std::vector &ruleLookupManagers); virtual ~ChartTranslationOptionCollection(); - void CreateTranslationOptionsForRange(size_t startPos - , size_t endPos); + void CreateTranslationOptionsForRange(const WordsRange &); - const ChartTranslationOptionList &GetTranslationOptionList(const WordsRange &range) const { - return GetTranslationOptionList(range.GetStartPos(), range.GetEndPos()); + const ChartTranslationOptionList &GetTranslationOptionList() const { + return m_translationOptionList; } + void Clear() { m_translationOptionList.Clear(); } + }; } - diff --git a/moses/src/ChartTranslationOptionList.cpp b/moses/src/ChartTranslationOptionList.cpp index 1444cb7b4..eadd4b688 100644 --- a/moses/src/ChartTranslationOptionList.cpp +++ b/moses/src/ChartTranslationOptionList.cpp @@ -1,4 +1,3 @@ -// $Id$ /*********************************************************************** Moses - factored phrase-based language decoder Copyright (C) 2010 Hieu Hoang @@ -29,14 +28,10 @@ namespace Moses { -#ifdef USE_HYPO_POOL -ObjectPool ChartTranslationOptionList::s_objectPool("ChartTranslationOptionList", 3000); -#endif - -ChartTranslationOptionList::ChartTranslationOptionList(const WordsRange &range) - :m_range(range) +ChartTranslationOptionList::ChartTranslationOptionList(size_t ruleLimit) + : m_size(0) + , m_ruleLimit(ruleLimit) { - m_collection.reserve(200); m_scoreThreshold = std::numeric_limits::infinity(); } @@ -45,6 +40,12 @@ ChartTranslationOptionList::~ChartTranslationOptionList() RemoveAllInColl(m_collection); } +void ChartTranslationOptionList::Clear() +{ + m_size = 0; + m_scoreThreshold = std::numeric_limits::infinity(); +} + class ChartTranslationOptionOrderer { public: @@ -53,80 +54,74 @@ public: } }; -void ChartTranslationOptionList::Add(const TargetPhraseCollection &targetPhraseCollection - , const StackVec &stackVec - , size_t ruleLimit) +void ChartTranslationOptionList::Add(const TargetPhraseCollection &tpc, + const StackVec &stackVec, + const WordsRange &range) { - if (targetPhraseCollection.IsEmpty()) { + if (tpc.IsEmpty()) { return; } - if (m_collection.size() < ruleLimit) { - // not yet filled out quota. add everything - ChartTranslationOption *option = new ChartTranslationOption( - targetPhraseCollection, stackVec, m_range); - m_collection.push_back(option); - float score = option->GetEstimateOfBestScore(); + float score = ChartTranslationOption::CalcEstimateOfBestScore(tpc, stackVec); + + // If the rule limit has already been reached then don't add the option + // unless it is better than at least one existing option. + if (m_size > m_ruleLimit && score < m_scoreThreshold) { + return; + } + + // Add the option to the list. + if (m_size == m_collection.size()) { + // m_collection has reached capacity: create a new object. + m_collection.push_back(new ChartTranslationOption(tpc, stackVec, + range, score)); + } else { + // Overwrite an unused object. + *(m_collection[m_size]) = ChartTranslationOption(tpc, stackVec, + range, score); + } + ++m_size; + + // If the rule limit hasn't been exceeded then update the threshold. + if (m_size <= m_ruleLimit) { m_scoreThreshold = (score < m_scoreThreshold) ? score : m_scoreThreshold; } - else { - // full but not bursting. add if better than worst score - ChartTranslationOption option(targetPhraseCollection, stackVec, m_range); - float score = option.GetEstimateOfBestScore(); - if (score > m_scoreThreshold) { - // dynamic allocation deferred until here on the assumption that most - // options will score below the threshold. - m_collection.push_back(new ChartTranslationOption(option)); - } - } - // prune if bursting - if (m_collection.size() > ruleLimit * 2) { - std::nth_element(m_collection.begin() - , m_collection.begin() + ruleLimit - , m_collection.end() - , ChartTranslationOptionOrderer()); - // delete the bottom half - for (size_t ind = ruleLimit; ind < m_collection.size(); ++ind) { - // make the best score of bottom half the score threshold - float score = m_collection[ind]->GetEstimateOfBestScore(); - m_scoreThreshold = (score > m_scoreThreshold) ? score : m_scoreThreshold; - delete m_collection[ind]; - } - m_collection.resize(ruleLimit); + // Prune if bursting + if (m_size == m_ruleLimit * 2) { + std::nth_element(m_collection.begin(), + m_collection.begin() + m_ruleLimit - 1, + m_collection.begin() + m_size, + ChartTranslationOptionOrderer()); + m_scoreThreshold = m_collection[m_ruleLimit-1]->GetEstimateOfBestScore(); + m_size = m_ruleLimit; } } -void ChartTranslationOptionList::Add(ChartTranslationOption *transOpt) +void ChartTranslationOptionList::ShrinkToLimit() { - CHECK(transOpt); - m_collection.push_back(transOpt); -} - -void ChartTranslationOptionList::CreateChartRules(size_t ruleLimit) -{ - if (m_collection.size() > ruleLimit) { - std::nth_element(m_collection.begin() - , m_collection.begin() + ruleLimit - , m_collection.end() - , ChartTranslationOptionOrderer()); - - // delete the bottom half - for (size_t ind = ruleLimit; ind < m_collection.size(); ++ind) { - delete m_collection[ind]; - } - m_collection.resize(ruleLimit); + if (m_size > m_ruleLimit) { + // Something's gone wrong if the list has grown to m_ruleLimit * 2 + // without being pruned. + assert(m_size < m_ruleLimit * 2); + // Reduce the list to the best m_ruleLimit options. The remaining + // options can be overwritten on subsequent calls to Add(). + std::nth_element(m_collection.begin(), + m_collection.begin()+m_ruleLimit, + m_collection.begin()+m_size, + ChartTranslationOptionOrderer()); + m_size = m_ruleLimit; } } - void ChartTranslationOptionList::ApplyThreshold() { // keep only those over best + threshold float scoreThreshold = -std::numeric_limits::infinity(); + CollType::const_iterator iter; - for (iter = m_collection.begin(); iter != m_collection.end(); ++iter) { + for (iter = m_collection.begin(); iter != m_collection.begin()+m_size; ++iter) { const ChartTranslationOption *transOpt = *iter; float score = transOpt->GetEstimateOfBestScore(); scoreThreshold = (score > scoreThreshold) ? score : scoreThreshold; @@ -135,14 +130,10 @@ void ChartTranslationOptionList::ApplyThreshold() scoreThreshold += StaticData::Instance().GetTranslationOptionThreshold(); CollType::iterator bound = std::partition(m_collection.begin(), - m_collection.end(), + m_collection.begin()+m_size, ScoreThresholdPred(scoreThreshold)); - for (CollType::iterator p = bound; p != m_collection.end(); ++p) { - delete *p; - } - - m_collection.erase(bound, m_collection.end()); + m_size = std::distance(m_collection.begin(), bound); } } diff --git a/moses/src/ChartTranslationOptionList.h b/moses/src/ChartTranslationOptionList.h index d4a2f470b..75ef73665 100644 --- a/moses/src/ChartTranslationOptionList.h +++ b/moses/src/ChartTranslationOptionList.h @@ -1,5 +1,3 @@ -// $Id$ - /*********************************************************************** Moses - factored phrase-based language decoder Copyright (C) 2006 University of Edinburgh @@ -21,22 +19,38 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA #pragma once -#include -#include -#include -#include #include "ChartTranslationOption.h" -#include "TargetPhrase.h" -#include "Util.h" -#include "TargetPhraseCollection.h" -#include "ObjectPool.h" +#include "StackVec.h" + +#include namespace Moses { + +class TargetPhraseCollection; +class WordsRange; + //! a list of target phrases that is trsnalated from the same source phrase class ChartTranslationOptionList { - friend std::ostream& operator<<(std::ostream&, const ChartTranslationOptionList&); + public: + ChartTranslationOptionList(size_t); + ~ChartTranslationOptionList(); + + const ChartTranslationOption &Get(size_t i) const { return *m_collection[i]; } + + //! number of translation options + size_t GetSize() const { return m_size; } + + void Add(const TargetPhraseCollection &, const StackVec &, + const WordsRange &); + + void Clear(); + void ShrinkToLimit(); + void ApplyThreshold(); + + private: + typedef std::vector CollType; struct ScoreThresholdPred { @@ -48,76 +62,10 @@ class ChartTranslationOptionList float m_thresholdScore; }; -protected: -#ifdef USE_HYPO_POOL - static ObjectPool s_objectPool; -#endif - typedef std::vector CollType; CollType m_collection; + size_t m_size; float m_scoreThreshold; - WordsRange m_range; - -public: - // iters - typedef CollType::iterator iterator; - typedef CollType::const_iterator const_iterator; - - iterator begin() { - return m_collection.begin(); - } - iterator end() { - return m_collection.end(); - } - const_iterator begin() const { - return m_collection.begin(); - } - const_iterator end() const { - return m_collection.end(); - } - -#ifdef USE_HYPO_POOL - void *operator new(size_t /* num_bytes */) { - void *ptr = s_objectPool.getPtr(); - return ptr; - } - - static void Delete(ChartTranslationOptionList *obj) { - s_objectPool.freeObject(obj); - } -#else - static void Delete(ChartTranslationOptionList *obj) { - delete obj; - } -#endif - - ChartTranslationOptionList(const WordsRange &range); - ~ChartTranslationOptionList(); - - const ChartTranslationOption &Get(size_t ind) const { - return *m_collection[ind]; - } - - //! number of target phrases in this collection - size_t GetSize() const { - return m_collection.size(); - } - //! wether collection has any phrases - bool IsEmpty() const { - return m_collection.empty(); - } - - void Add(const TargetPhraseCollection &targetPhraseCollection - , const StackVec &stackVec - , size_t ruleLimit); - void Add(ChartTranslationOption *transOpt); - - void CreateChartRules(size_t ruleLimit); - - const WordsRange &GetSourceRange() const { - return m_range; - } - - void ApplyThreshold(); + const size_t m_ruleLimit; }; } diff --git a/moses/src/Scope3Parser/Parser.cpp b/moses/src/Scope3Parser/Parser.cpp index cd36f4154..960aa263f 100644 --- a/moses/src/Scope3Parser/Parser.cpp +++ b/moses/src/Scope3Parser/Parser.cpp @@ -43,11 +43,10 @@ void Scope3Parser::GetChartRuleCollection( { const size_t start = range.GetStartPos(); const size_t end = range.GetEndPos(); - const size_t ruleLimit = StaticData::Instance().GetRuleLimit(); std::vector > &pairVec = m_ruleApplications[start][end-start+1]; - MatchCallback matchCB(ruleLimit, outColl); + MatchCallback matchCB(range, outColl); for (std::vector >::const_iterator p = pairVec.begin(); p != pairVec.end(); ++p) { const UTrieNode &ruleNode = *(p->first); const VarSpanNode &varSpanNode = *(p->second); diff --git a/moses/src/Scope3Parser/Parser.h b/moses/src/Scope3Parser/Parser.h index 0c97f87c4..0b2386469 100644 --- a/moses/src/Scope3Parser/Parser.h +++ b/moses/src/Scope3Parser/Parser.h @@ -65,16 +65,16 @@ class Scope3Parser : public ChartRuleLookupManager struct MatchCallback { public: - MatchCallback(size_t ruleLimit, + MatchCallback(const WordsRange &range, ChartTranslationOptionList &out) - : m_ruleLimit(ruleLimit) + : m_range(range) , m_out(out) , m_tpc(NULL) {} void operator()(const StackVec &stackVec) { - m_out.Add(*m_tpc, stackVec, m_ruleLimit); + m_out.Add(*m_tpc, stackVec, m_range); } - size_t m_ruleLimit; + const WordsRange &m_range; ChartTranslationOptionList &m_out; const TargetPhraseCollection *m_tpc; };