Change the Encoder class to Vocabulary.

- Introduce the namespace to avoid naming collisions. The class name
  is used in KenLM.
- Add the unit test.
This commit is contained in:
Tetsuo Kiso 2012-03-20 03:43:04 +09:00
parent 2b28072f7a
commit 525f06452c
7 changed files with 114 additions and 65 deletions

View File

@ -64,7 +64,7 @@ void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
{
// Make sure reference data is clear
m_references.reset();
ClearEncoder();
ClearVocabulary();
//load reference data
for (size_t i = 0; i < referenceFiles.size(); ++i) {

View File

@ -56,5 +56,6 @@ unit-test optimizer_factory_test : OptimizerFactoryTest.cpp mert_lib ..//boost_u
unit-test reference_test : ReferenceTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test timer_test : TimerTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test util_test : UtilTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test vocabulary_test : VocabularyTest.cpp mert_lib ..//boost_unit_test_framework ;
install legacy : programs : <location>. ;

View File

@ -5,61 +5,6 @@
#include <map>
#include <string>
/**
* A map to manage vocaburaries.
*/
class Encoder {
public:
typedef std::map<std::string, int>::iterator iterator;
typedef std::map<std::string, int>::const_iterator const_iterator;
Encoder() {}
virtual ~Encoder() {}
/** Returns the assiged id for given "token". */
int Encode(const std::string& token) {
iterator it = m_vocab.find(token);
int encoded_token;
if (it == m_vocab.end()) {
// Add an new entry to the vocaburary.
encoded_token = static_cast<int>(m_vocab.size());
m_vocab[token] = encoded_token;
} else {
encoded_token = it->second;
}
return encoded_token;
}
/**
* Return true iff the specified "str" is found in the container.
*/
bool Lookup(const std::string&str , int* v) const {
const_iterator it = m_vocab.find(str);
if (it == m_vocab.end()) return false;
*v = it->second;
return true;
}
void clear() { m_vocab.clear(); }
bool empty() const { return m_vocab.empty(); }
size_t size() const { return m_vocab.size(); }
iterator find(const std::string& str) { return m_vocab.find(str); }
const_iterator find(const std::string& str) const { return m_vocab.find(str); }
int& operator[](const std::string& str) { return m_vocab[str]; }
iterator begin() { return m_vocab.begin(); }
const_iterator begin() const { return m_vocab.begin(); }
iterator end() { return m_vocab.end(); }
const_iterator end() const { return m_vocab.end(); }
private:
std::map<std::string, int> m_vocab;
};
/** A simple STL-std::map based n-gram counts. Basically, we provide
* typical accessors and mutaors, but we intentionally does not allow
* erasing elements.

View File

@ -1,10 +1,9 @@
#include "Scorer.h"
#include <limits>
#include "Ngram.h"
#include "Vocabulary.h"
#include "Util.h"
namespace {
//regularisation strategies
@ -37,14 +36,14 @@ inline float score_average(const statscores_t& scores, size_t start, size_t end)
Scorer::Scorer(const string& name, const string& config)
: m_name(name),
m_encoder(new Encoder),
m_vocab(new mert::Vocabulary),
m_score_data(0),
m_enable_preserve_case(true) {
InitConfig(config);
}
Scorer::~Scorer() {
delete m_encoder;
delete m_vocab;
}
void Scorer::InitConfig(const string& config) {
@ -78,11 +77,11 @@ void Scorer::TokenizeAndEncode(const string& line, vector<int>& encoded) {
*it = tolower(*it);
}
}
encoded.push_back(m_encoder->Encode(token));
encoded.push_back(m_vocab->Encode(token));
}
}
void Scorer::ClearEncoder() { m_encoder->clear(); }
void Scorer::ClearVocabulary() { m_vocab->clear(); }
/**
* Set the factors, which should be used for this metric

View File

@ -12,7 +12,12 @@
using namespace std;
class ScoreStats;
class Encoder;
namespace mert {
class Vocabulary;
} // namespace mert
/**
* Superclass of all scorers and dummy implementation.
@ -108,11 +113,13 @@ class Scorer
*/
virtual string applyFactors(const string& sentece) const;
mert::Vocabulary* GetVocab() const { return m_vocab; }
private:
void InitConfig(const string& config);
string m_name;
Encoder* m_encoder;
mert::Vocabulary* m_vocab;
map<string, string> m_config;
vector<int> m_factors;
@ -138,7 +145,7 @@ class Scorer
*/
void TokenizeAndEncode(const string& line, vector<int>& encoded);
void ClearEncoder();
void ClearVocabulary();
};
/**

69
mert/Vocabulary.h Normal file
View File

@ -0,0 +1,69 @@
#ifndef MERT_VOCABULARY_H_
#define MERT_VOCABULARY_H_
#include <map>
#include <string>
namespace mert {
/**
* A embarrassingly simple map to handle vocabularies to calculate
* various scores such as BLEU.
*
* TODO: replace this with more efficient data structure.
*/
class Vocabulary {
public:
typedef std::map<std::string, int>::iterator iterator;
typedef std::map<std::string, int>::const_iterator const_iterator;
Vocabulary() {}
virtual ~Vocabulary() {}
/** Returns the assiged id for given "token". */
int Encode(const std::string& token) {
iterator it = m_vocab.find(token);
int encoded_token;
if (it == m_vocab.end()) {
// Add an new entry to the vocaburary.
encoded_token = static_cast<int>(m_vocab.size());
m_vocab[token] = encoded_token;
} else {
encoded_token = it->second;
}
return encoded_token;
}
/**
* Return true iff the specified "str" is found in the container.
*/
bool Lookup(const std::string&str , int* v) const {
const_iterator it = m_vocab.find(str);
if (it == m_vocab.end()) return false;
*v = it->second;
return true;
}
void clear() { m_vocab.clear(); }
bool empty() const { return m_vocab.empty(); }
size_t size() const { return m_vocab.size(); }
iterator find(const std::string& str) { return m_vocab.find(str); }
const_iterator find(const std::string& str) const { return m_vocab.find(str); }
int& operator[](const std::string& str) { return m_vocab[str]; }
iterator begin() { return m_vocab.begin(); }
const_iterator begin() const { return m_vocab.begin(); }
iterator end() { return m_vocab.end(); }
const_iterator end() const { return m_vocab.end(); }
private:
std::map<std::string, int> m_vocab;
};
} // namespace mert
#endif // MERT_VOCABULARY_H_

28
mert/VocabularyTest.cpp Normal file
View File

@ -0,0 +1,28 @@
#include "Vocabulary.h"
#define BOOST_TEST_MODULE MertVocabulary
#include <boost/test/unit_test.hpp>
BOOST_AUTO_TEST_CASE(vocab_basic) {
mert::Vocabulary vocab;
BOOST_REQUIRE(vocab.empty());
vocab.clear();
BOOST_CHECK_EQUAL(0, vocab.Encode("hello"));
BOOST_CHECK_EQUAL(0, vocab.Encode("hello"));
BOOST_CHECK_EQUAL(1, vocab.Encode("world"));
BOOST_CHECK_EQUAL(2, vocab.size());
int v;
BOOST_CHECK(vocab.Lookup("hello", &v));
BOOST_CHECK_EQUAL(0, v);
BOOST_CHECK(vocab.Lookup("world", &v));
BOOST_CHECK_EQUAL(1, v);
BOOST_CHECK(!vocab.Lookup("java", &v));
vocab.clear();
BOOST_CHECK(!vocab.Lookup("hello", &v));
BOOST_CHECK(!vocab.Lookup("world", &v));
}