Create Reference class to clean up BleuScorer.

- Add a unit test for Reference.
- Move functions to calculate the reference length from
  BleuScorer to Reference.
This commit is contained in:
Tetsuo Kiso 2012-03-18 05:58:40 +09:00
parent 918bcafb80
commit 6b95a19eda
5 changed files with 236 additions and 82 deletions

View File

@ -7,6 +7,7 @@
#include <iostream>
#include <stdexcept>
#include "Ngram.h"
#include "Reference.h"
#include "Util.h"
namespace {
@ -19,7 +20,6 @@ const char REFLEN_CLOSEST[] = "closest";
} // namespace
BleuScorer::BleuScorer(const string& config)
: StatisticsBasedScorer("BLEU", config),
m_ref_length_type(CLOSEST) {
@ -60,9 +60,8 @@ size_t BleuScorer::countNgrams(const string& line, NgramCounts& counts,
void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
{
//make sure reference data is clear
m_ref_counts.reset();
m_ref_lengths.clear();
// Make sure reference data is clear
m_references.reset();
ClearEncoder();
//load reference data
@ -77,12 +76,10 @@ void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
while (getline(refin,line)) {
line = this->applyFactors(line);
if (i == 0) {
NgramCounts *counts = new NgramCounts; //these get leaked
m_ref_counts.push_back(counts);
vector<size_t> lengths;
m_ref_lengths.push_back(lengths);
Reference* ref = new Reference;
m_references.push_back(ref); // Take ownership of the Reference object.
}
if (m_ref_counts.size() <= sid) {
if (m_references.size() <= sid) {
throw runtime_error("File " + referenceFiles[i] + " has too many sentences");
}
NgramCounts counts;
@ -94,13 +91,13 @@ void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
const NgramCounts::Value newcount = ci->second;
NgramCounts::Value oldcount = 0;
m_ref_counts[sid]->lookup(ngram, &oldcount);
m_references[sid]->get_counts()->lookup(ngram, &oldcount);
if (newcount > oldcount) {
m_ref_counts[sid]->operator[](ngram) = newcount;
m_references[sid]->get_counts()->operator[](ngram) = newcount;
}
}
//add in the length
m_ref_lengths[sid].push_back(length);
m_references[sid]->push_back(length);
if (sid > 0 && sid % 100 == 0) {
TRACE_ERR(".");
}
@ -112,7 +109,7 @@ void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
{
if (sid >= m_ref_counts.size()) {
if (sid >= m_references.size()) {
stringstream msg;
msg << "Sentence id (" << sid << ") not found in reference set";
throw runtime_error(msg.str());
@ -123,20 +120,8 @@ void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
string sentence = this->applyFactors(text);
const size_t length = countNgrams(sentence, testcounts, kBleuNgramOrder);
// Calculate effective reference length.
switch (m_ref_length_type) {
case SHORTEST:
CalcShortest(sid, stats);
break;
case AVERAGE:
CalcAverage(sid, stats);
break;
case CLOSEST:
CalcClosest(sid, length, stats);
break;
default:
throw runtime_error("Unsupported reflength strategy");
}
const int reference_len = CalcReferenceLength(sid, length);
stats.push_back(reference_len);
//precision on each ngram type
for (NgramCounts::const_iterator testcounts_it = testcounts.begin();
@ -146,7 +131,7 @@ void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
NgramCounts::Value correct = 0;
NgramCounts::Value v = 0;
if (m_ref_counts[sid]->lookup(testcounts_it->first, &v)) {
if (m_references[sid]->get_counts()->lookup(testcounts_it->first, &v)) {
correct = min(v, guess);
}
stats[len * 2 - 2] += correct;
@ -174,6 +159,23 @@ float BleuScorer::calculateScore(const vector<int>& comps) const
return exp(logbleu);
}
int BleuScorer::CalcReferenceLength(size_t sentence_id, size_t length) {
switch (m_ref_length_type) {
case AVERAGE:
return m_references[sentence_id]->CalcAverage();
break;
case CLOSEST:
return m_references[sentence_id]->CalcClosest(length);
break;
case SHORTEST:
return m_references[sentence_id]->CalcShortest();
break;
default:
cerr << "unknown reference types." << endl;
exit(1);
}
}
void BleuScorer::dump_counts(ostream* os,
const NgramCounts& counts) const {
for (NgramCounts::const_iterator it = counts.begin();
@ -191,44 +193,3 @@ void BleuScorer::dump_counts(ostream* os,
*os << endl;
}
// Appends the average of the reference lengths stored for |sentence_id|
// to |stats|. The mean is computed in float and then converted to
// ScoreStatsType.
// NOTE(review): if no length is stored for |sentence_id| the division below
// has a zero divisor -- confirm every sentence has at least one length.
void BleuScorer::CalcAverage(size_t sentence_id,
vector<ScoreStatsType>& stats) const {
int total = 0;
for (size_t i = 0;
i < m_ref_lengths[sentence_id].size(); ++i) {
total += m_ref_lengths[sentence_id][i];
}
const float mean = static_cast<float>(total) /
m_ref_lengths[sentence_id].size();
stats.push_back(static_cast<ScoreStatsType>(mean));
}
// Appends to |stats| the reference length of |sentence_id| that is closest
// to |length| (the candidate length). When two reference lengths are
// equally close, the shorter one is chosen.
// NOTE(review): with no stored lengths the final lookup reads index 0 of an
// empty vector -- confirm every sentence has at least one length.
void BleuScorer::CalcClosest(size_t sentence_id,
size_t length,
vector<ScoreStatsType>& stats) const {
// Signed difference of the best reference so far; only its absolute value
// is used in the comparisons below.
int min_diff = INT_MAX;
int min_idx = 0;
for (size_t i = 0; i < m_ref_lengths[sentence_id].size(); ++i) {
const int reflength = m_ref_lengths[sentence_id][i];
const int length_diff = abs(reflength - static_cast<int>(length));
// Look for the closest reference.
if (length_diff < abs(min_diff)) {
min_diff = reflength - length;
min_idx = i;
// If two references have the same closest length, take the shortest.
} else if (length_diff == abs(min_diff)) {
if (reflength < static_cast<int>(m_ref_lengths[sentence_id][min_idx])) {
min_idx = i;
}
}
}
stats.push_back(m_ref_lengths[sentence_id][min_idx]);
}
// Appends the smallest reference length stored for |sentence_id| to |stats|.
// NOTE(review): min_element on an empty range returns end(), which is
// dereferenced here -- confirm every sentence has at least one length.
void BleuScorer::CalcShortest(size_t sentence_id,
vector<ScoreStatsType>& stats) const {
const int shortest = *min_element(m_ref_lengths[sentence_id].begin(),
m_ref_lengths[sentence_id].end());
stats.push_back(shortest);
}

View File

@ -15,6 +15,7 @@ using namespace std;
const int kBleuNgramOrder = 4;
class NgramCounts;
class Reference;
/**
* Bleu scoring
@ -30,6 +31,8 @@ public:
virtual float calculateScore(const vector<int>& comps) const;
virtual size_t NumberOfScores() const { return 2 * kBleuNgramOrder + 1; }
int CalcReferenceLength(size_t sentence_id, size_t length);
private:
enum ReferenceLengthType {
AVERAGE,
@ -44,19 +47,10 @@ private:
void dump_counts(std::ostream* os, const NgramCounts& counts) const;
// For calculating effective reference length.
void CalcAverage(size_t sentence_id,
vector<ScoreStatsType>& stats) const;
void CalcClosest(size_t sentence_id, size_t length,
vector<ScoreStatsType>& stats) const;
void CalcShortest(size_t sentence_id,
vector<ScoreStatsType>& stats) const;
ReferenceLengthType m_ref_length_type;
// data extracted from reference files
ScopedVector<NgramCounts> m_ref_counts;
vector<vector<size_t> > m_ref_lengths;
// reference translations.
ScopedVector<Reference> m_references;
// no copying allowed
BleuScorer(const BleuScorer&);

View File

@ -6,9 +6,13 @@ lib mert_lib :
Util.cpp
FileStream.cpp
Timer.cpp
ScoreStats.cpp ScoreArray.cpp ScoreData.cpp
ScoreStats.cpp
ScoreArray.cpp
ScoreData.cpp
ScoreDataIterator.cpp
FeatureStats.cpp FeatureArray.cpp FeatureData.cpp
FeatureStats.cpp
FeatureArray.cpp
FeatureData.cpp
FeatureDataIterator.cpp
Data.cpp
BleuScorer.cpp
@ -47,6 +51,7 @@ alias programs : mert extractor evaluator pro ;
unit-test feature_data_test : FeatureDataTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test data_test : DataTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test ngram_test : NgramTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test reference_test : ReferenceTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test timer_test : TimerTest.cpp mert_lib ..//boost_unit_test_framework ;
unit-test util_test : UtilTest.cpp mert_lib ..//boost_unit_test_framework ;

78
mert/Reference.h Normal file
View File

@ -0,0 +1,78 @@
#ifndef MERT_REFERENCE_H_
#define MERT_REFERENCE_H_
#include <algorithm>
#include <climits>
#include <iostream>
#include <vector>
#include "Ngram.h"
// Reference class is a reference translation for an output translation.
//
// It bundles two pieces of data:
//   - an NgramCounts object, allocated in the constructor and released in
//     the destructor, reachable through get_counts();
//   - a list of reference lengths, filled with push_back() and exposed
//     through begin()/end().
//
// The Calc*() methods derive a single effective length from that list.
class Reference {
 public:
  // Iterators over the stored reference lengths (m_length).
  typedef std::vector<size_t>::iterator iterator;
  typedef std::vector<size_t>::const_iterator const_iterator;

  Reference() : m_counts(new NgramCounts) { }
  ~Reference() { delete m_counts; }

  // Access to the n-gram counts. The Reference keeps ownership; callers
  // must not delete the returned pointer.
  NgramCounts* get_counts() { return m_counts; }
  const NgramCounts* get_counts() const { return m_counts; }

  iterator begin() { return m_length.begin(); }
  const_iterator begin() const { return m_length.begin(); }
  iterator end() { return m_length.end(); }
  const_iterator end() const { return m_length.end(); }

  // Records the length of one more reference translation.
  void push_back(size_t len) { m_length.push_back(len); }

  // Number of reference lengths recorded so far.
  size_t num_references() const { return m_length.size(); }

  // Mean of the recorded lengths, truncated to an int.
  int CalcAverage() const;

  // Recorded length closest to |length|; ties go to the shorter one.
  int CalcClosest(size_t length) const;

  // Smallest recorded length.
  int CalcShortest() const;

 private:
  // No copying allowed: a member-wise copy would share m_counts and both
  // destructors would delete it. Declared but intentionally not defined.
  Reference(const Reference&);
  Reference& operator=(const Reference&);

  NgramCounts* m_counts;

  // multiple reference lengths
  std::vector<size_t> m_length;
};
inline int Reference::CalcAverage() const {
int total = 0;
for (size_t i = 0; i < m_length.size(); ++i) {
total += m_length[i];
}
return static_cast<int>(
static_cast<float>(total) / m_length.size());
}
// Returns the recorded reference length that is closest to |length|
// (the length of the candidate translation). If two reference lengths are
// equally close, the shorter one is returned. Returns 0 when no length has
// been recorded.
inline int Reference::CalcClosest(size_t length) const {
  if (m_length.empty()) {
    // No reference to compare against; avoid indexing an empty vector.
    return 0;
  }
  // Do all arithmetic in int so the differences below never wrap around
  // the way a mixed int/size_t subtraction would.
  const int hyp_length = static_cast<int>(length);
  int best_length = static_cast<int>(m_length[0]);
  int best_diff = INT_MAX;
  for (size_t i = 0; i < m_length.size(); ++i) {
    const int ref_length = static_cast<int>(m_length[i]);
    const int diff = (ref_length > hyp_length) ? ref_length - hyp_length
                                               : hyp_length - ref_length;
    // Prefer a strictly closer reference; on a tie prefer the shorter one.
    if (diff < best_diff ||
        (diff == best_diff && ref_length < best_length)) {
      best_diff = diff;
      best_length = ref_length;
    }
  }
  return best_length;
}
// Returns the smallest recorded reference length, or 0 when no length has
// been recorded.
inline int Reference::CalcShortest() const {
  if (m_length.empty()) {
    // min_element would return end() here, which must not be dereferenced.
    return 0;
  }
  return static_cast<int>(
      *std::min_element(m_length.begin(), m_length.end()));
}
#endif // MERT_REFERENCE_H_

116
mert/ReferenceTest.cpp Normal file
View File

@ -0,0 +1,116 @@
#include "Reference.h"
#define BOOST_TEST_MODULE MertReference
#include <boost/test/unit_test.hpp>
// A default-constructed Reference must expose a non-null counts object.
BOOST_AUTO_TEST_CASE(refernece_count) {
Reference ref;
BOOST_CHECK(ref.get_counts() != NULL);
}
// begin()/end() must walk the lengths in the order they were pushed.
BOOST_AUTO_TEST_CASE(refernece_length_iterator) {
Reference ref;
ref.push_back(4);
ref.push_back(2);
// Stop here if the size is wrong; the iterator checks below rely on it.
BOOST_REQUIRE(ref.num_references() == 2);
Reference::iterator it = ref.begin();
BOOST_CHECK_EQUAL(*it, 4);
++it;
BOOST_CHECK_EQUAL(*it, 2);
++it;
BOOST_CHECK(it == ref.end());
}
// CalcAverage() returns the mean of the lengths truncated to an int.
BOOST_AUTO_TEST_CASE(refernece_length_average) {
{
// (4 + 1) / 2 = 2.5, truncated to 2.
Reference ref;
ref.push_back(4);
ref.push_back(1);
BOOST_CHECK_EQUAL(2, ref.CalcAverage());
}
{
// (4 + 3) / 2 = 3.5, truncated to 3.
Reference ref;
ref.push_back(4);
ref.push_back(3);
BOOST_CHECK_EQUAL(3, ref.CalcAverage());
}
{
// (4 + 3 + 4 + 5) / 4 = 4 exactly.
Reference ref;
ref.push_back(4);
ref.push_back(3);
ref.push_back(4);
ref.push_back(5);
BOOST_CHECK_EQUAL(4, ref.CalcAverage());
}
}
// CalcClosest(n) returns the stored length nearest to n.
BOOST_AUTO_TEST_CASE(refernece_length_closest) {
{
// Lengths {4, 1}: 2 is nearer to 1, while 3 and above are nearer to 4.
Reference ref;
ref.push_back(4);
ref.push_back(1);
BOOST_REQUIRE(ref.num_references() == 2);
BOOST_CHECK_EQUAL(1, ref.CalcClosest(2));
BOOST_CHECK_EQUAL(1, ref.CalcClosest(1));
BOOST_CHECK_EQUAL(4, ref.CalcClosest(3));
BOOST_CHECK_EQUAL(4, ref.CalcClosest(4));
BOOST_CHECK_EQUAL(4, ref.CalcClosest(5));
}
{
// Lengths {4, 3}: anything up to 3 maps to 3, anything from 4 maps to 4.
Reference ref;
ref.push_back(4);
ref.push_back(3);
BOOST_REQUIRE(ref.num_references() == 2);
BOOST_CHECK_EQUAL(3, ref.CalcClosest(1));
BOOST_CHECK_EQUAL(3, ref.CalcClosest(2));
BOOST_CHECK_EQUAL(3, ref.CalcClosest(3));
BOOST_CHECK_EQUAL(4, ref.CalcClosest(4));
BOOST_CHECK_EQUAL(4, ref.CalcClosest(5));
}
{
// Lengths {4, 3, 4, 5}: includes a duplicated length and an exact match
// for every query from 3 to 5.
Reference ref;
ref.push_back(4);
ref.push_back(3);
ref.push_back(4);
ref.push_back(5);
BOOST_REQUIRE(ref.num_references() == 4);
BOOST_CHECK_EQUAL(3, ref.CalcClosest(1));
BOOST_CHECK_EQUAL(3, ref.CalcClosest(2));
BOOST_CHECK_EQUAL(3, ref.CalcClosest(3));
BOOST_CHECK_EQUAL(4, ref.CalcClosest(4));
BOOST_CHECK_EQUAL(5, ref.CalcClosest(5));
}
}
// CalcShortest() returns the minimum of the stored lengths, regardless of
// the order in which they were pushed.
BOOST_AUTO_TEST_CASE(refernece_length_shortest) {
{
Reference ref;
ref.push_back(4);
ref.push_back(1);
BOOST_CHECK_EQUAL(1, ref.CalcShortest());
}
{
Reference ref;
ref.push_back(4);
ref.push_back(3);
BOOST_CHECK_EQUAL(3, ref.CalcShortest());
}
{
// The minimum is neither the first nor the last element.
Reference ref;
ref.push_back(4);
ref.push_back(3);
ref.push_back(4);
ref.push_back(5);
BOOST_CHECK_EQUAL(3, ref.CalcShortest());
}
}