clean up comparison functions for Words and Phrases

This commit is contained in:
Hieu Hoang 2015-10-17 21:43:03 +01:00
parent 5754a46905
commit 2683b58b53
15 changed files with 29 additions and 76 deletions

View File

@ -78,15 +78,6 @@ public:
return m_bestScore;
}
bool operator<(const ChartCellLabel &other) const {
// m_coverage and m_label uniquely identify a ChartCellLabel, so don't
// need to compare m_stack.
if (m_coverage == other.m_coverage) {
return m_label < other.m_label;
}
return m_coverage < other.m_coverage;
}
private:
const WordsRange &m_coverage;
const Word &m_label;

View File

@ -40,11 +40,7 @@ bool BleuScoreState::operator==(const FFState& o) const
return true;
const BleuScoreState& other = static_cast<const BleuScoreState&>(o);
int c = m_words.Compare(other.m_words);
if (c == 0)
return true;
return false;
return m_words == other.m_words;
}
std::ostream& operator<<(std::ostream& out, const BleuScoreState& state)

View File

@ -50,8 +50,7 @@ bool ControlRecombinationState::operator==(const FFState& other) const
const ControlRecombinationState &otherFF = static_cast<const ControlRecombinationState&>(other);
if (m_ff.GetType() == SameOutput) {
int ret = m_outputPhrase.Compare(otherFF.m_outputPhrase);
return ret == 0;
return m_outputPhrase == otherFF.m_outputPhrase;
} else {
// compare hypo address. Won't be equal unless they're actually the same hypo
if (m_hypo == otherFF.m_hypo)

View File

@ -1,5 +1,4 @@
#include <vector>
#include <set>
#include "NieceTerminal.h"
#include "moses/ScoreComponentCollection.h"
#include "moses/TargetPhrase.h"
@ -45,7 +44,7 @@ void NieceTerminal::EvaluateWithSourceContext(const InputType &input
const Phrase *ruleSource = targetPhrase.GetRuleSource();
assert(ruleSource);
std::set<Word> terms;
boost::unordered_set<Word> terms;
for (size_t i = 0; i < ruleSource->GetSize(); ++i) {
const Word &word = ruleSource->GetWord(i);
if (!word.IsNonTerminal()) {
@ -81,9 +80,9 @@ void NieceTerminal::EvaluateWhenApplied(const ChartHypothesis &hypo,
bool NieceTerminal::ContainTerm(const InputType &input,
const WordsRange &ntRange,
const std::set<Word> &terms) const
const boost::unordered_set<Word> &terms) const
{
std::set<Word>::const_iterator iter;
boost::unordered_set<Word>::const_iterator iter;
for (size_t pos = ntRange.GetStartPos(); pos <= ntRange.GetEndPos(); ++pos) {
const Word &word = input.GetWord(pos);

View File

@ -1,6 +1,6 @@
#pragma once
#include <set>
#include <boost/unordered_set.hpp>
#include <string>
#include "StatelessFeatureFunction.h"
@ -46,7 +46,7 @@ protected:
bool m_hardConstraint;
bool ContainTerm(const InputType &input,
const WordsRange &ntRange,
const std::set<Word> &terms) const;
const boost::unordered_set<Word> &terms) const;
};
}

View File

@ -21,23 +21,23 @@ size_t TargetNgramState::hash() const
bool TargetNgramState::operator==(const FFState& other) const
{
const TargetNgramState& rhs = dynamic_cast<const TargetNgramState&>(other);
int result;
bool result;
if (m_words.size() == rhs.m_words.size()) {
for (size_t i = 0; i < m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]);
if (result != 0) return false;
result = m_words[i] == rhs.m_words[i];
if (!result) return false;
}
return true;
} else if (m_words.size() < rhs.m_words.size()) {
for (size_t i = 0; i < m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]);
if (result != 0) return false;
result = m_words[i] == rhs.m_words[i];
if (!result) return false;
}
return true;
} else {
for (size_t i = 0; i < rhs.m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]);
if (result != 0) return false;
result = m_words[i] == rhs.m_words[i];
if (!result) return false;
}
return true;
}

View File

@ -182,14 +182,12 @@ public:
// prefix
if (m_startPos > 0) { // not for "<s> ..."
int ret = GetPrefix().Compare(other.GetPrefix());
if (ret != 0)
if (GetPrefix() != other.GetPrefix())
return false;
}
if (m_endPos < m_inputSize - 1) { // not for "... </s>"
int ret = GetSuffix().Compare(other.GetSuffix());
if (ret != 0)
if (GetSuffix() != other.GetSuffix())
return false;
}
return true;

View File

@ -23,9 +23,9 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#define moses_GenerationDictionary_h
#include <list>
#include <map>
#include <stdexcept>
#include <vector>
#include <boost/unordered_map.hpp>
#include "ScoreComponentCollection.h"
#include "Phrase.h"
#include "TypeDef.h"
@ -36,7 +36,7 @@ namespace Moses
class FactorCollection;
typedef std::map < Word , ScoreComponentCollection > OutputWordCollection;
typedef boost::unordered_map < Word , ScoreComponentCollection > OutputWordCollection;
// 1st = output phrase
// 2nd = log probability (score)

View File

@ -1,10 +1,10 @@
#pragma once
#include <set>
#include <vector>
#include <boost/shared_ptr.hpp>
#include <boost/unordered_map.hpp>
#include <boost/unordered_set.hpp>
#include "moses/InputType.h"
#include "moses/Syntax/KBestExtractor.h"

View File

@ -65,7 +65,7 @@ void Manager::OutputUnknowns(OutputCollector *collector) const
long translationId = m_source.GetTranslationId();
std::ostringstream out;
for (std::set<Moses::Word>::const_iterator p = m_oovs.begin();
for (boost::unordered_set<Moses::Word>::const_iterator p = m_oovs.begin();
p != m_oovs.end(); ++p) {
out << *p;
}

View File

@ -1,5 +1,6 @@
#pragma once
#include <boost/unordered_set.hpp>
#include "moses/InputType.h"
#include "moses/BaseManager.h"
@ -50,7 +51,7 @@ public:
virtual const SHyperedge *GetBestSHyperedge() const = 0;
protected:
std::set<Word> m_oovs;
boost::unordered_set<Word> m_oovs;
private:
// Syntax-specific helper functions used to implement OutputNBest.

View File

@ -108,7 +108,7 @@ void Manager<Parser>::InitializeParsers(PChart &pchart,
// Find the set of OOVs for this input. This function assumes that the
// PChart argument has already been initialized from the input.
template<typename Parser>
void Manager<Parser>::FindOovs(const PChart &pchart, std::set<Word> &oovs,
void Manager<Parser>::FindOovs(const PChart &pchart, boost::unordered_set<Word> &oovs,
std::size_t maxOovWidth)
{
// Get the set of RuleTries.

View File

@ -45,7 +45,7 @@ public:
void OutputDetailedTranslationReport(OutputCollector *collector) const;
private:
void FindOovs(const PChart &, std::set<Word> &, std::size_t);
void FindOovs(const PChart &, boost::unordered_set<Word> &, std::size_t);
void InitializeCharts();

View File

@ -230,29 +230,6 @@ void swap(TargetPhrase &first, TargetPhrase &second);
std::ostream& operator<<(std::ostream&, const TargetPhrase&);
/**
* Hasher that looks at source and target phrase.
**/
struct TargetPhraseHasher {
inline size_t operator()(const TargetPhrase& targetPhrase) const {
size_t seed = 0;
boost::hash_combine(seed, targetPhrase);
boost::hash_combine(seed, targetPhrase.GetAlignTerm());
boost::hash_combine(seed, targetPhrase.GetAlignNonTerm());
return seed;
}
};
struct TargetPhraseComparator {
inline bool operator()(const TargetPhrase& lhs, const TargetPhrase& rhs) const {
return lhs.Compare(rhs) == 0 &&
lhs.GetAlignTerm() == rhs.GetAlignTerm() &&
lhs.GetAlignNonTerm() == rhs.GetAlignNonTerm();
}
};
}
#endif

View File

@ -116,14 +116,6 @@ public:
StringPiece GetString(FactorType factorType) const;
TO_STRING();
//! transitive comparison of Word objects
inline bool operator< (const Word &compare) const {
// needed to store word in GenerationDictionary map
// uses comparison of FactorKey
// 'proper' comparison, not address/id comparison
return Compare(*this, compare) < 0;
}
bool operator== (const Word &compare) const;
inline bool operator!= (const Word &compare) const {
@ -153,6 +145,11 @@ public:
}
};
inline size_t hash_value(const Word& word)
{
return word.hash();
}
struct WordComparer {
size_t operator()(const Word* word) const {
return word->hash();
@ -165,11 +162,6 @@ struct WordComparer {
};
inline size_t hash_value(const Word& word)
{
return word.hash();
}
}
#endif