clean up comparison functions for Words and Phrases

This commit is contained in:
Hieu Hoang 2015-10-17 21:43:03 +01:00
parent 5754a46905
commit 2683b58b53
15 changed files with 29 additions and 76 deletions

View File

@ -78,15 +78,6 @@ public:
return m_bestScore; return m_bestScore;
} }
bool operator<(const ChartCellLabel &other) const {
// m_coverage and m_label uniquely identify a ChartCellLabel, so don't
// need to compare m_stack.
if (m_coverage == other.m_coverage) {
return m_label < other.m_label;
}
return m_coverage < other.m_coverage;
}
private: private:
const WordsRange &m_coverage; const WordsRange &m_coverage;
const Word &m_label; const Word &m_label;

View File

@ -40,11 +40,7 @@ bool BleuScoreState::operator==(const FFState& o) const
return true; return true;
const BleuScoreState& other = static_cast<const BleuScoreState&>(o); const BleuScoreState& other = static_cast<const BleuScoreState&>(o);
int c = m_words.Compare(other.m_words); return m_words == other.m_words;
if (c == 0)
return true;
return false;
} }
std::ostream& operator<<(std::ostream& out, const BleuScoreState& state) std::ostream& operator<<(std::ostream& out, const BleuScoreState& state)

View File

@ -50,8 +50,7 @@ bool ControlRecombinationState::operator==(const FFState& other) const
const ControlRecombinationState &otherFF = static_cast<const ControlRecombinationState&>(other); const ControlRecombinationState &otherFF = static_cast<const ControlRecombinationState&>(other);
if (m_ff.GetType() == SameOutput) { if (m_ff.GetType() == SameOutput) {
int ret = m_outputPhrase.Compare(otherFF.m_outputPhrase); return m_outputPhrase == otherFF.m_outputPhrase;
return ret == 0;
} else { } else {
// compare hypo address. Won't be equal unless they're actually the same hypo // compare hypo address. Won't be equal unless they're actually the same hypo
if (m_hypo == otherFF.m_hypo) if (m_hypo == otherFF.m_hypo)

View File

@ -1,5 +1,4 @@
#include <vector> #include <vector>
#include <set>
#include "NieceTerminal.h" #include "NieceTerminal.h"
#include "moses/ScoreComponentCollection.h" #include "moses/ScoreComponentCollection.h"
#include "moses/TargetPhrase.h" #include "moses/TargetPhrase.h"
@ -45,7 +44,7 @@ void NieceTerminal::EvaluateWithSourceContext(const InputType &input
const Phrase *ruleSource = targetPhrase.GetRuleSource(); const Phrase *ruleSource = targetPhrase.GetRuleSource();
assert(ruleSource); assert(ruleSource);
std::set<Word> terms; boost::unordered_set<Word> terms;
for (size_t i = 0; i < ruleSource->GetSize(); ++i) { for (size_t i = 0; i < ruleSource->GetSize(); ++i) {
const Word &word = ruleSource->GetWord(i); const Word &word = ruleSource->GetWord(i);
if (!word.IsNonTerminal()) { if (!word.IsNonTerminal()) {
@ -81,9 +80,9 @@ void NieceTerminal::EvaluateWhenApplied(const ChartHypothesis &hypo,
bool NieceTerminal::ContainTerm(const InputType &input, bool NieceTerminal::ContainTerm(const InputType &input,
const WordsRange &ntRange, const WordsRange &ntRange,
const std::set<Word> &terms) const const boost::unordered_set<Word> &terms) const
{ {
std::set<Word>::const_iterator iter; boost::unordered_set<Word>::const_iterator iter;
for (size_t pos = ntRange.GetStartPos(); pos <= ntRange.GetEndPos(); ++pos) { for (size_t pos = ntRange.GetStartPos(); pos <= ntRange.GetEndPos(); ++pos) {
const Word &word = input.GetWord(pos); const Word &word = input.GetWord(pos);

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#include <set> #include <boost/unordered_set.hpp>
#include <string> #include <string>
#include "StatelessFeatureFunction.h" #include "StatelessFeatureFunction.h"
@ -46,7 +46,7 @@ protected:
bool m_hardConstraint; bool m_hardConstraint;
bool ContainTerm(const InputType &input, bool ContainTerm(const InputType &input,
const WordsRange &ntRange, const WordsRange &ntRange,
const std::set<Word> &terms) const; const boost::unordered_set<Word> &terms) const;
}; };
} }

View File

@ -21,23 +21,23 @@ size_t TargetNgramState::hash() const
bool TargetNgramState::operator==(const FFState& other) const bool TargetNgramState::operator==(const FFState& other) const
{ {
const TargetNgramState& rhs = dynamic_cast<const TargetNgramState&>(other); const TargetNgramState& rhs = dynamic_cast<const TargetNgramState&>(other);
int result; bool result;
if (m_words.size() == rhs.m_words.size()) { if (m_words.size() == rhs.m_words.size()) {
for (size_t i = 0; i < m_words.size(); ++i) { for (size_t i = 0; i < m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]); result = m_words[i] == rhs.m_words[i];
if (result != 0) return false; if (!result) return false;
} }
return true; return true;
} else if (m_words.size() < rhs.m_words.size()) { } else if (m_words.size() < rhs.m_words.size()) {
for (size_t i = 0; i < m_words.size(); ++i) { for (size_t i = 0; i < m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]); result = m_words[i] == rhs.m_words[i];
if (result != 0) return false; if (!result) return false;
} }
return true; return true;
} else { } else {
for (size_t i = 0; i < rhs.m_words.size(); ++i) { for (size_t i = 0; i < rhs.m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]); result = m_words[i] == rhs.m_words[i];
if (result != 0) return false; if (!result) return false;
} }
return true; return true;
} }

View File

@ -182,14 +182,12 @@ public:
// prefix // prefix
if (m_startPos > 0) { // not for "<s> ..." if (m_startPos > 0) { // not for "<s> ..."
int ret = GetPrefix().Compare(other.GetPrefix()); if (GetPrefix() != other.GetPrefix())
if (ret != 0)
return false; return false;
} }
if (m_endPos < m_inputSize - 1) { // not for "... </s>" if (m_endPos < m_inputSize - 1) { // not for "... </s>"
int ret = GetSuffix().Compare(other.GetSuffix()); if (GetSuffix() != other.GetSuffix())
if (ret != 0)
return false; return false;
} }
return true; return true;

View File

@ -23,9 +23,9 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#define moses_GenerationDictionary_h #define moses_GenerationDictionary_h
#include <list> #include <list>
#include <map>
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
#include <boost/unordered_map.hpp>
#include "ScoreComponentCollection.h" #include "ScoreComponentCollection.h"
#include "Phrase.h" #include "Phrase.h"
#include "TypeDef.h" #include "TypeDef.h"
@ -36,7 +36,7 @@ namespace Moses
class FactorCollection; class FactorCollection;
typedef std::map < Word , ScoreComponentCollection > OutputWordCollection; typedef boost::unordered_map < Word , ScoreComponentCollection > OutputWordCollection;
// 1st = output phrase // 1st = output phrase
// 2nd = log probability (score) // 2nd = log probability (score)

View File

@ -1,10 +1,10 @@
#pragma once #pragma once
#include <set>
#include <vector> #include <vector>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <boost/unordered_map.hpp> #include <boost/unordered_map.hpp>
#include <boost/unordered_set.hpp>
#include "moses/InputType.h" #include "moses/InputType.h"
#include "moses/Syntax/KBestExtractor.h" #include "moses/Syntax/KBestExtractor.h"

View File

@ -65,7 +65,7 @@ void Manager::OutputUnknowns(OutputCollector *collector) const
long translationId = m_source.GetTranslationId(); long translationId = m_source.GetTranslationId();
std::ostringstream out; std::ostringstream out;
for (std::set<Moses::Word>::const_iterator p = m_oovs.begin(); for (boost::unordered_set<Moses::Word>::const_iterator p = m_oovs.begin();
p != m_oovs.end(); ++p) { p != m_oovs.end(); ++p) {
out << *p; out << *p;
} }

View File

@ -1,5 +1,6 @@
#pragma once #pragma once
#include <boost/unordered_set.hpp>
#include "moses/InputType.h" #include "moses/InputType.h"
#include "moses/BaseManager.h" #include "moses/BaseManager.h"
@ -50,7 +51,7 @@ public:
virtual const SHyperedge *GetBestSHyperedge() const = 0; virtual const SHyperedge *GetBestSHyperedge() const = 0;
protected: protected:
std::set<Word> m_oovs; boost::unordered_set<Word> m_oovs;
private: private:
// Syntax-specific helper functions used to implement OutputNBest. // Syntax-specific helper functions used to implement OutputNBest.

View File

@ -108,7 +108,7 @@ void Manager<Parser>::InitializeParsers(PChart &pchart,
// Find the set of OOVs for this input. This function assumes that the // Find the set of OOVs for this input. This function assumes that the
// PChart argument has already been initialized from the input. // PChart argument has already been initialized from the input.
template<typename Parser> template<typename Parser>
void Manager<Parser>::FindOovs(const PChart &pchart, std::set<Word> &oovs, void Manager<Parser>::FindOovs(const PChart &pchart, boost::unordered_set<Word> &oovs,
std::size_t maxOovWidth) std::size_t maxOovWidth)
{ {
// Get the set of RuleTries. // Get the set of RuleTries.

View File

@ -45,7 +45,7 @@ public:
void OutputDetailedTranslationReport(OutputCollector *collector) const; void OutputDetailedTranslationReport(OutputCollector *collector) const;
private: private:
void FindOovs(const PChart &, std::set<Word> &, std::size_t); void FindOovs(const PChart &, boost::unordered_set<Word> &, std::size_t);
void InitializeCharts(); void InitializeCharts();

View File

@ -230,29 +230,6 @@ void swap(TargetPhrase &first, TargetPhrase &second);
std::ostream& operator<<(std::ostream&, const TargetPhrase&); std::ostream& operator<<(std::ostream&, const TargetPhrase&);
/**
* Hasher that looks at source and target phrase.
**/
struct TargetPhraseHasher {
inline size_t operator()(const TargetPhrase& targetPhrase) const {
size_t seed = 0;
boost::hash_combine(seed, targetPhrase);
boost::hash_combine(seed, targetPhrase.GetAlignTerm());
boost::hash_combine(seed, targetPhrase.GetAlignNonTerm());
return seed;
}
};
struct TargetPhraseComparator {
inline bool operator()(const TargetPhrase& lhs, const TargetPhrase& rhs) const {
return lhs.Compare(rhs) == 0 &&
lhs.GetAlignTerm() == rhs.GetAlignTerm() &&
lhs.GetAlignNonTerm() == rhs.GetAlignNonTerm();
}
};
} }
#endif #endif

View File

@ -116,14 +116,6 @@ public:
StringPiece GetString(FactorType factorType) const; StringPiece GetString(FactorType factorType) const;
TO_STRING(); TO_STRING();
//! transitive comparison of Word objects
inline bool operator< (const Word &compare) const {
// needed to store word in GenerationDictionary map
// uses comparison of FactorKey
// 'proper' comparison, not address/id comparison
return Compare(*this, compare) < 0;
}
bool operator== (const Word &compare) const; bool operator== (const Word &compare) const;
inline bool operator!= (const Word &compare) const { inline bool operator!= (const Word &compare) const {
@ -153,6 +145,11 @@ public:
} }
}; };
inline size_t hash_value(const Word& word)
{
return word.hash();
}
struct WordComparer { struct WordComparer {
size_t operator()(const Word* word) const { size_t operator()(const Word* word) const {
return word->hash(); return word->hash();
@ -165,11 +162,6 @@ struct WordComparer {
}; };
inline size_t hash_value(const Word& word)
{
return word.hash();
}
} }
#endif #endif