Move Encoder class from Scorer.h to Ngram.h.

To add unit tests.
This commit is contained in:
Tetsuo Kiso 2012-03-19 23:21:02 +09:00
parent f686e8771a
commit 2b28072f7a
5 changed files with 87 additions and 54 deletions

View File

@ -54,7 +54,7 @@ size_t BleuScorer::CountNgrams(const string& line, NgramCounts& counts,
for (size_t j = i; j < i+k && j < encoded_tokens.size(); ++j) {
ngram.push_back(encoded_tokens[j]);
}
counts.add(ngram);
counts.Add(ngram);
}
}
return encoded_tokens.size();
@ -93,7 +93,7 @@ void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
const NgramCounts::Value newcount = ci->second;
NgramCounts::Value oldcount = 0;
m_references[sid]->get_counts()->lookup(ngram, &oldcount);
m_references[sid]->get_counts()->Lookup(ngram, &oldcount);
if (newcount > oldcount) {
m_references[sid]->get_counts()->operator[](ngram) = newcount;
}
@ -133,7 +133,7 @@ void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
NgramCounts::Value correct = 0;
NgramCounts::Value v = 0;
if (m_references[sid]->get_counts()->lookup(testcounts_it->first, &v)) {
if (m_references[sid]->get_counts()->Lookup(testcounts_it->first, &v)) {
correct = min(v, guess);
}
stats[len * 2 - 2] += correct;

View File

@ -3,6 +3,62 @@
#include <vector>
#include <map>
#include <string>
/**
* A map to manage vocaburaries.
*/
class Encoder {
public:
typedef std::map<std::string, int>::iterator iterator;
typedef std::map<std::string, int>::const_iterator const_iterator;
Encoder() {}
virtual ~Encoder() {}
/** Returns the assiged id for given "token". */
int Encode(const std::string& token) {
iterator it = m_vocab.find(token);
int encoded_token;
if (it == m_vocab.end()) {
// Add an new entry to the vocaburary.
encoded_token = static_cast<int>(m_vocab.size());
m_vocab[token] = encoded_token;
} else {
encoded_token = it->second;
}
return encoded_token;
}
/**
* Return true iff the specified "str" is found in the container.
*/
bool Lookup(const std::string&str , int* v) const {
const_iterator it = m_vocab.find(str);
if (it == m_vocab.end()) return false;
*v = it->second;
return true;
}
void clear() { m_vocab.clear(); }
bool empty() const { return m_vocab.empty(); }
size_t size() const { return m_vocab.size(); }
iterator find(const std::string& str) { return m_vocab.find(str); }
const_iterator find(const std::string& str) const { return m_vocab.find(str); }
int& operator[](const std::string& str) { return m_vocab[str]; }
iterator begin() { return m_vocab.begin(); }
const_iterator begin() const { return m_vocab.begin(); }
iterator end() { return m_vocab.end(); }
const_iterator end() const { return m_vocab.end(); }
private:
std::map<std::string, int> m_vocab;
};
/** A simple STL-std::map based n-gram counts. Basically, we provide
* typical accessors and mutaors, but we intentionally does not allow
@ -40,7 +96,7 @@ class NgramCounts {
/**
* If the specified "ngram" is found, we add counts.
* If not, we insert the default count in the container. */
void add(const Key& ngram) {
void Add(const Key& ngram) {
const_iterator it = find(ngram);
if (it != end()) {
m_counts[ngram] = it->second + 1;
@ -49,6 +105,16 @@ class NgramCounts {
}
}
/**
* Return true iff the specified "ngram" is found in the container.
*/
bool Lookup(const Key& ngram, Value* v) const {
const_iterator it = m_counts.find(ngram);
if (it == m_counts.end()) return false;
*v = it->second;
return true;
}
/**
* Clear all elments in the container.
*/
@ -69,16 +135,6 @@ class NgramCounts {
// Note: This is mainly used by unit tests.
int get_default_count() const { return kDefaultCount; }
/**
* Return true iff the specified "ngram" is found in the container.
*/
bool lookup(const Key& ngram, Value* v) const {
const_iterator it = m_counts.find(ngram);
if (it == m_counts.end()) return false;
*v = it->second;
return true;
}
iterator find(const Key& ngram) { return m_counts.find(ngram); }
const_iterator find(const Key& ngram) const { return m_counts.find(ngram); }

View File

@ -9,7 +9,7 @@ BOOST_AUTO_TEST_CASE(ngram_basic) {
key.push_back(1);
key.push_back(2);
key.push_back(4);
counts.add(key);
counts.Add(key);
BOOST_REQUIRE(!counts.empty());
BOOST_CHECK_EQUAL(counts.size(), 1);
@ -23,26 +23,26 @@ BOOST_AUTO_TEST_CASE(ngram_basic) {
BOOST_CHECK_EQUAL(it->second, 1);
}
BOOST_AUTO_TEST_CASE(ngram_add) {
BOOST_AUTO_TEST_CASE(ngram_Add) {
NgramCounts counts;
NgramCounts::Key key;
key.push_back(1);
key.push_back(2);
counts.add(key);
counts.Add(key);
BOOST_REQUIRE(!counts.empty());
BOOST_CHECK_EQUAL(counts[key], counts.get_default_count());
NgramCounts::Key key2;
key2.push_back(1);
key2.push_back(2);
counts.add(key2);
counts.Add(key2);
BOOST_CHECK_EQUAL(counts.size(), 1);
BOOST_CHECK_EQUAL(counts[key], counts.get_default_count() + 1);
BOOST_CHECK_EQUAL(counts[key2], counts.get_default_count() + 1);
NgramCounts::Key key3;
key3.push_back(10);
counts.add(key3);
counts.Add(key3);
BOOST_CHECK_EQUAL(counts.size(), 2);
BOOST_CHECK_EQUAL(counts[key3], counts.get_default_count());
}
@ -53,11 +53,11 @@ BOOST_AUTO_TEST_CASE(ngram_lookup) {
key.push_back(1);
key.push_back(2);
key.push_back(4);
counts.add(key);
counts.Add(key);
{
NgramCounts::Value v;
BOOST_REQUIRE(counts.lookup(key, &v));
BOOST_REQUIRE(counts.Lookup(key, &v));
BOOST_CHECK_EQUAL(v, 1);
}
@ -70,7 +70,7 @@ BOOST_AUTO_TEST_CASE(ngram_lookup) {
// We only check the return value;
// we don't check the value of "v" because it makes sense
// to check the value when the specified ngram is found.
BOOST_REQUIRE(!counts.lookup(key2, &v));
BOOST_REQUIRE(!counts.Lookup(key2, &v));
}
// test after clear
@ -78,6 +78,6 @@ BOOST_AUTO_TEST_CASE(ngram_lookup) {
BOOST_CHECK(counts.empty());
{
NgramCounts::Value v;
BOOST_CHECK(!counts.lookup(key, &v));
BOOST_CHECK(!counts.Lookup(key, &v));
}
}

View File

@ -1,7 +1,10 @@
#include "Scorer.h"
#include <limits>
#include "Ngram.h"
#include "Util.h"
namespace {
//regularisation strategies
@ -65,23 +68,6 @@ void Scorer::InitConfig(const string& config) {
}
}
Scorer::Encoder::Encoder() {}
Scorer::Encoder::~Encoder() {}
int Scorer::Encoder::Encode(const string& token) {
map<string, int>::iterator it = m_vocab.find(token);
int encoded_token;
if (it == m_vocab.end()) {
// Add an new entry to the vocaburary.
encoded_token = static_cast<int>(m_vocab.size());
m_vocab[token] = encoded_token;
} else {
encoded_token = it->second;
}
return encoded_token;
}
void Scorer::TokenizeAndEncode(const string& line, vector<int>& encoded) {
std::istringstream in(line);
std::string token;
@ -96,6 +82,8 @@ void Scorer::TokenizeAndEncode(const string& line, vector<int>& encoded) {
}
}
void Scorer::ClearEncoder() { m_encoder->clear(); }
/**
* Set the factors, which should be used for this metric
*/

View File

@ -12,6 +12,7 @@
using namespace std;
class ScoreStats;
class Encoder;
/**
* Superclass of all scorers and dummy implementation.
@ -108,17 +109,6 @@ class Scorer
virtual string applyFactors(const string& sentece) const;
private:
class Encoder {
public:
Encoder();
virtual ~Encoder();
int Encode(const std::string& token);
void Clear() { m_vocab.clear(); }
private:
std::map<std::string, int> m_vocab;
};
void InitConfig(const string& config);
string m_name;
@ -144,14 +134,13 @@ class Scorer
/**
* Tokenise line and encode.
* Note: We assume that all tokens are separated by single spaces.
* Note: We assume that all tokens are separated by whitespaces.
*/
void TokenizeAndEncode(const string& line, vector<int>& encoded);
void ClearEncoder() { m_encoder->Clear(); }
void ClearEncoder();
};
/**
* Abstract base class for Scorers that work by adding statistics across all
* outout sentences, then apply some formula, e.g., BLEU, PER.