Use boost::unordered_map for faster lookup in ChartCell and ChartCellLabelSet.

This commit is contained in:
Phil Williams 2012-01-23 14:19:19 +00:00
parent 8462f24b72
commit e6a5bd9c20
5 changed files with 52 additions and 35 deletions

View File

@ -57,15 +57,6 @@ ChartCell::~ChartCell()
delete m_sourceWordLabel; delete m_sourceWordLabel;
} }
/**
 * Return the sorted hypothesis list kept under the given constituent label.
 * The label is looked up in m_hypoColl and CHECK requires that it was found.
 */
const HypoList &ChartCell::GetSortedHypotheses(const Word &constituentLabel) const
{
  typedef std::map<Word, ChartHypothesisCollection> CollectionMap;
  const CollectionMap::const_iterator iter = m_hypoColl.find(constituentLabel);
  // The CHECK condition is kept verbatim in case the macro prints its argument text.
  CHECK(iter != m_hypoColl.end());
  const ChartHypothesisCollection &collection = iter->second;
  return collection.GetSortedHypotheses();
}
/** Add the given hypothesis to the cell */ /** Add the given hypothesis to the cell */
bool ChartCell::AddHypothesis(ChartHypothesis *hypo) bool ChartCell::AddHypothesis(ChartHypothesis *hypo)
{ {
@ -76,7 +67,7 @@ bool ChartCell::AddHypothesis(ChartHypothesis *hypo)
/** Pruning */ /** Pruning */
void ChartCell::PruneToSize() void ChartCell::PruneToSize()
{ {
std::map<Word, ChartHypothesisCollection>::iterator iter; MapType::iterator iter;
for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) { for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) {
ChartHypothesisCollection &coll = iter->second; ChartHypothesisCollection &coll = iter->second;
coll.PruneToSize(m_manager); coll.PruneToSize(m_manager);
@ -118,7 +109,7 @@ void ChartCell::SortHypotheses()
{ {
// sort each mini cells & fill up target lhs list // sort each mini cells & fill up target lhs list
CHECK(m_targetLabelSet.Empty()); CHECK(m_targetLabelSet.Empty());
std::map<Word, ChartHypothesisCollection>::iterator iter; MapType::iterator iter;
for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) { for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) {
ChartHypothesisCollection &coll = iter->second; ChartHypothesisCollection &coll = iter->second;
m_targetLabelSet.AddConstituent(iter->first, coll); m_targetLabelSet.AddConstituent(iter->first, coll);
@ -132,7 +123,7 @@ const ChartHypothesis *ChartCell::GetBestHypothesis() const
const ChartHypothesis *ret = NULL; const ChartHypothesis *ret = NULL;
float bestScore = -std::numeric_limits<float>::infinity(); float bestScore = -std::numeric_limits<float>::infinity();
std::map<Word, ChartHypothesisCollection>::const_iterator iter; MapType::const_iterator iter;
for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) { for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) {
const HypoList &sortedList = iter->second.GetSortedHypotheses(); const HypoList &sortedList = iter->second.GetSortedHypotheses();
CHECK(sortedList.size() > 0); CHECK(sortedList.size() > 0);
@ -152,7 +143,7 @@ void ChartCell::CleanupArcList()
// only necessary if n-best calculations are enabled // only necessary if n-best calculations are enabled
if (!m_nBestIsEnabled) return; if (!m_nBestIsEnabled) return;
std::map<Word, ChartHypothesisCollection>::iterator iter; MapType::iterator iter;
for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) { for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) {
ChartHypothesisCollection &coll = iter->second; ChartHypothesisCollection &coll = iter->second;
coll.CleanupArcList(); coll.CleanupArcList();
@ -161,7 +152,7 @@ void ChartCell::CleanupArcList()
void ChartCell::OutputSizes(std::ostream &out) const void ChartCell::OutputSizes(std::ostream &out) const
{ {
std::map<Word, ChartHypothesisCollection>::const_iterator iter; MapType::const_iterator iter;
for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) { for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) {
const Word &targetLHS = iter->first; const Word &targetLHS = iter->first;
const ChartHypothesisCollection &coll = iter->second; const ChartHypothesisCollection &coll = iter->second;
@ -173,7 +164,7 @@ void ChartCell::OutputSizes(std::ostream &out) const
size_t ChartCell::GetSize() const size_t ChartCell::GetSize() const
{ {
size_t ret = 0; size_t ret = 0;
std::map<Word, ChartHypothesisCollection>::const_iterator iter; MapType::const_iterator iter;
for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) { for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) {
const ChartHypothesisCollection &coll = iter->second; const ChartHypothesisCollection &coll = iter->second;
@ -185,7 +176,7 @@ size_t ChartCell::GetSize() const
void ChartCell::GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream, const std::map<unsigned, bool> &reachable) const void ChartCell::GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream, const std::map<unsigned, bool> &reachable) const
{ {
std::map<Word, ChartHypothesisCollection>::const_iterator iterOutside; MapType::const_iterator iterOutside;
for (iterOutside = m_hypoColl.begin(); iterOutside != m_hypoColl.end(); ++iterOutside) { for (iterOutside = m_hypoColl.begin(); iterOutside != m_hypoColl.end(); ++iterOutside) {
const ChartHypothesisCollection &coll = iterOutside->second; const ChartHypothesisCollection &coll = iterOutside->second;
coll.GetSearchGraph(translationId, outputSearchGraphStream, reachable); coll.GetSearchGraph(translationId, outputSearchGraphStream, reachable);
@ -194,7 +185,7 @@ void ChartCell::GetSearchGraph(long translationId, std::ostream &outputSearchGra
std::ostream& operator<<(std::ostream &out, const ChartCell &cell) std::ostream& operator<<(std::ostream &out, const ChartCell &cell)
{ {
std::map<Word, ChartHypothesisCollection>::const_iterator iterOutside; ChartCell::MapType::const_iterator iterOutside;
for (iterOutside = cell.m_hypoColl.begin(); iterOutside != cell.m_hypoColl.end(); ++iterOutside) { for (iterOutside = cell.m_hypoColl.begin(); iterOutside != cell.m_hypoColl.end(); ++iterOutside) {
const Word &targetLHS = iterOutside->first; const Word &targetLHS = iterOutside->first;
cerr << targetLHS << ":" << endl; cerr << targetLHS << ":" << endl;

View File

@ -33,6 +33,10 @@
#include "RuleCube.h" #include "RuleCube.h"
#include "ChartCellLabelSet.h" #include "ChartCellLabelSet.h"
#include <boost/functional/hash.hpp>
#include <boost/unordered_map.hpp>
#include <boost/version.hpp>
namespace Moses namespace Moses
{ {
class ChartTranslationOptionList; class ChartTranslationOptionList;
@ -44,9 +48,18 @@ class ChartCell
{ {
friend std::ostream& operator<<(std::ostream&, const ChartCell&); friend std::ostream& operator<<(std::ostream&, const ChartCell&);
public: public:
#if defined(BOOST_VERSION) && (BOOST_VERSION >= 104200)
typedef boost::unordered_map<Word,
ChartHypothesisCollection,
NonTerminalHasher,
NonTerminalEqualityPred
> MapType;
#else
typedef std::map<Word, ChartHypothesisCollection> MapType;
#endif
protected: protected:
std::map<Word, ChartHypothesisCollection> m_hypoColl; MapType m_hypoColl;
WordsRange m_coverage; WordsRange m_coverage;
@ -63,7 +76,13 @@ public:
void ProcessSentence(const ChartTranslationOptionList &transOptList void ProcessSentence(const ChartTranslationOptionList &transOptList
,const ChartCellCollection &allChartCells); ,const ChartCellCollection &allChartCells);
const HypoList &GetSortedHypotheses(const Word &constituentLabel) const; /** Get all hypotheses in the cell that have the specified constituent label */
const HypoList *GetSortedHypotheses(const Word &constituentLabel) const
{
  // A label with no entry in m_hypoColl yields NULL; otherwise the result
  // points at the sorted list held by the matching collection.
  const MapType::const_iterator entry = m_hypoColl.find(constituentLabel);
  if (entry == m_hypoColl.end()) {
    return NULL;
  }
  const ChartHypothesisCollection &collection = entry->second;
  return &(collection.GetSortedHypotheses());
}
bool AddHypothesis(ChartHypothesis *hypo); bool AddHypothesis(ChartHypothesis *hypo);
void SortHypotheses(); void SortHypotheses();

View File

@ -20,8 +20,11 @@
#pragma once #pragma once
#include "ChartCellLabel.h" #include "ChartCellLabel.h"
#include "NonTerminal.h"
#include <set> #include <boost/functional/hash.hpp>
#include <boost/unordered_map.hpp>
#include <boost/version.hpp>
namespace Moses namespace Moses
{ {
@ -31,41 +34,45 @@ class ChartHypothesisCollection;
/**
 * Holds the ChartCellLabel entries that share one WordsRange coverage.
 * In this side-by-side rendering each line carries the earlier std::set-based
 * text followed by the revised text, which stores the labels in a map keyed by Word.
 */
class ChartCellLabelSet class ChartCellLabelSet
{ {
private: private:
// Revised container: boost::unordered_map with NonTerminalHasher / NonTerminalEqualityPred
// when BOOST_VERSION >= 104200, otherwise std::map.
// NOTE(review): no #include <map> is visible for the std::map fallback - confirm it is
// pulled in by another header.
typedef std::set<ChartCellLabel> SetType; #if defined(BOOST_VERSION) && (BOOST_VERSION >= 104200)
typedef boost::unordered_map<Word, ChartCellLabel,
NonTerminalHasher, NonTerminalEqualityPred
> MapType;
#else
typedef std::map<Word, ChartCellLabel> MapType;
#endif
public: public:
// With the map-based container an iterator refers to a (Word, ChartCellLabel) pair,
// so callers reach the label through ->second.
typedef SetType::const_iterator const_iterator; typedef MapType::const_iterator const_iterator;
// m_coverage is a reference member bound to the caller's WordsRange, so that object
// has to outlive this set.
ChartCellLabelSet(const WordsRange &coverage) : m_coverage(coverage) {} ChartCellLabelSet(const WordsRange &coverage) : m_coverage(coverage) {}
// Read-only iteration over the stored entries.
const_iterator begin() const { return m_set.begin(); } const_iterator begin() const { return m_map.begin(); }
const_iterator end() const { return m_set.end(); } const_iterator end() const { return m_map.end(); }
// Adds a label for w built from the coverage and w only.
// insert() does not replace an entry that is already present.
void AddWord(const Word &w) void AddWord(const Word &w)
{ {
ChartCellLabel cellLabel(m_coverage, w); m_map.insert(std::make_pair(w, ChartCellLabel(m_coverage, w)));
m_set.insert(cellLabel);
} }
// Adds a label for w that additionally receives the address of the given
// hypothesis collection. insert() does not replace an entry that is already present.
void AddConstituent(const Word &w, const ChartHypothesisCollection &stack) void AddConstituent(const Word &w, const ChartHypothesisCollection &stack)
{ {
ChartCellLabel cellLabel(m_coverage, w, &stack); m_map.insert(std::make_pair(w, ChartCellLabel(m_coverage, w, &stack)));
m_set.insert(cellLabel);
} }
// True when no label has been added.
bool Empty() const { return m_set.empty(); } bool Empty() const { return m_map.empty(); }
// Number of stored labels.
size_t GetSize() const { return m_set.size(); } size_t GetSize() const { return m_map.size(); }
// Returns a pointer to the label stored for w, or 0 when there is none.
const ChartCellLabel *Find(const Word &w) const const ChartCellLabel *Find(const Word &w) const
{ {
SetType::const_iterator p = m_set.find(ChartCellLabel(m_coverage, w)); MapType::const_iterator p = m_map.find(w);
return p == m_set.end() ? 0 : &(*p); return p == m_map.end() ? 0 : &(p->second);
} }
private: private:
// Coverage passed to every ChartCellLabel this set constructs (held by reference).
const WordsRange &m_coverage; const WordsRange &m_coverage;
// Label storage: the earlier std::set on the left, the revised map on the right.
SetType m_set; MapType m_map;
}; };
} }

View File

@ -225,7 +225,7 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
ChartCellLabelSet::const_iterator q = targetNonTerms.begin(); ChartCellLabelSet::const_iterator q = targetNonTerms.begin();
ChartCellLabelSet::const_iterator tEnd = targetNonTerms.end(); ChartCellLabelSet::const_iterator tEnd = targetNonTerms.end();
for (; q != tEnd; ++q) { for (; q != tEnd; ++q) {
const ChartCellLabel &cellLabel = *q; const ChartCellLabel &cellLabel = q->second;
// try to match both source and target non-terminal // try to match both source and target non-terminal
const PhraseDictionaryNodeSCFG * child = const PhraseDictionaryNodeSCFG * child =

View File

@ -174,7 +174,7 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
// go through each TARGET lhs // go through each TARGET lhs
ChartCellLabelSet::const_iterator iterChartNonTerm; ChartCellLabelSet::const_iterator iterChartNonTerm;
for (iterChartNonTerm = chartNonTermSet.begin(); iterChartNonTerm != chartNonTermSet.end(); ++iterChartNonTerm) { for (iterChartNonTerm = chartNonTermSet.begin(); iterChartNonTerm != chartNonTermSet.end(); ++iterChartNonTerm) {
const ChartCellLabel &cellLabel = *iterChartNonTerm; const ChartCellLabel &cellLabel = iterChartNonTerm->second;
//cerr << sourceLHS << " " << defaultSourceNonTerm << " " << chartNonTerm << " " << defaultTargetNonTerm << endl; //cerr << sourceLHS << " " << defaultSourceNonTerm << " " << chartNonTerm << " " << defaultTargetNonTerm << endl;