Add test cases for BLEU and sentence-level BLEU+1.

- Move a definition of sentenceLevelBleuPlusOne() from pro.cpp
  to BleuScorer.cpp.
- Add check for the length of an input vector.
This commit is contained in:
Tetsuo Kiso 2012-04-07 01:02:32 +09:00
parent bcc1958d94
commit d034eeb703
4 changed files with 91 additions and 16 deletions

View File

@ -7,6 +7,8 @@
#include <fstream>
#include <iostream>
#include <stdexcept>
#include "util/check.hh"
#include "Ngram.h"
#include "Reference.h"
#include "Util.h"
@ -160,6 +162,8 @@ void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
float BleuScorer::calculateScore(const vector<int>& comps) const
{
CHECK(comps.size() == kBleuNgramOrder * 2 + 1);
float logbleu = 0.0;
for (int i = 0; i < kBleuNgramOrder; ++i) {
if (comps[2*i] == 0) {
@ -211,3 +215,18 @@ void BleuScorer::DumpCounts(ostream* os,
*os << endl;
}
/**
 * Sentence-level BLEU with add-one smoothing (BLEU+1).
 *
 * stats must hold exactly kBleuNgramOrder * 2 + 1 values. For every n-gram
 * order n, stats[2 * n] and stats[2 * n + 1] form the (numerator, denominator)
 * pair of that order's ratio; both are incremented by one before the log is
 * taken. The final entry is divided by stats[1] to obtain the brevity term.
 *
 * Returns exp() of the averaged smoothed log-ratios, lowered by the brevity
 * term whenever that term is negative.
 */
float sentenceLevelBleuPlusOne(const vector<float>& stats) {
  CHECK(stats.size() == kBleuNgramOrder * 2 + 1);

  // Accumulate the smoothed log-ratio of every n-gram order.
  float logbleu = 0.0;
  for (int order = 0; order < kBleuNgramOrder; ++order) {
    const double numerator = stats[2 * order] + 1.0;
    const double denominator = stats[2 * order + 1] + 1.0;
    logbleu += log(numerator) - log(denominator);
  }
  logbleu /= kBleuNgramOrder;

  // The last entry relative to stats[1]; only a negative value penalizes.
  const float last_entry = stats[(kBleuNgramOrder * 2)];
  const float brevity = 1.0 - last_entry / stats[1];
  if (brevity < 0.0) {
    logbleu += brevity;
  }

  return exp(logbleu);
}

View File

@ -67,4 +67,9 @@ private:
BleuScorer& operator=(const BleuScorer&);
};
/** Computes sentence-level BLEU+1 score.
* This function is used in PRO.
*
* @param stats  Must contain exactly kBleuNgramOrder * 2 + 1 values (the
*               definition enforces this with CHECK). For each n-gram order
*               n, stats[2 * n] and stats[2 * n + 1] are the numerator and
*               denominator of that order's ratio before add-one smoothing.
*               The last value is divided by stats[1] for the brevity term
*               (presumably the reference length; it is labelled that way
*               in the unit test).
* @return exp() of the averaged smoothed log-ratios, reduced by the brevity
*         term when that term is negative.
*/
float sentenceLevelBleuPlusOne(const vector<float>& stats);
#endif // MERT_BLEU_SCORER_H_

View File

@ -3,6 +3,7 @@
#define BOOST_TEST_MODULE MertBleuScorer
#include <boost/test/unit_test.hpp>
#include <cmath>
#include "Ngram.h"
#include "Vocabulary.h"
#include "Util.h"
@ -110,6 +111,19 @@ void SetUpReferences(BleuScorer& scorer) {
}
}
// Absolute tolerance used by IsAlmostEqual() when comparing scores.
const float kEPS = 0.0001f;

/**
 * Tells whether two values differ by strictly less than kEPS.
 *
 * On a mismatch the expected and actual values are written to stderr so
 * that the failing Boost check can be diagnosed from the test log.
 *
 * @param expected  The reference value.
 * @param actual    The value produced by the code under test.
 * @return true if |expected - actual| < kEPS, false otherwise (including
 *         when the difference is NaN).
 */
template <typename T>
bool IsAlmostEqual(T expected, T actual) {
  // std::abs is qualified on purpose: an unqualified abs() can resolve to the
  // C integer ::abs(int), which truncates the difference to 0 and would make
  // every mismatch smaller than 1.0 look "almost equal".
  if (std::abs(expected - actual) < kEPS) {
    return true;
  }
  std::cerr << "Fail: expected = " << expected
            << " (actual = " << actual << ")" << std::endl;
  return false;
}
} // namespace
BOOST_AUTO_TEST_CASE(bleu_reference_type) {
@ -204,3 +218,56 @@ BOOST_AUTO_TEST_CASE(bleu_clipped_counts) {
BOOST_CHECK_EQUAL(entry.get(5), 4); // trigram
BOOST_CHECK_EQUAL(entry.get(7), 3); // fourgram
}
// Feeds hand-built statistics to BleuScorer::calculateScore() and compares
// the result with a precomputed BLEU value.
BOOST_AUTO_TEST_CASE(calculate_actual_score) {
  BOOST_REQUIRE(4 == kBleuNgramOrder);

  // Two entries per n-gram order followed by the reference length; the
  // table has 2 * kBleuNgramOrder + 1 elements for the order required above.
  const int kStats[] = {
    6, 6,  // unigram
    4, 5,  // bigram
    2, 4,  // trigram
    1, 3,  // fourgram
    7      // reference-length
  };
  vector<int> stats(kStats, kStats + sizeof(kStats) / sizeof(kStats[0]));
  BleuScorer scorer;

  BOOST_CHECK(IsAlmostEqual(0.5115f, scorer.calculateScore(stats)));
}
// Feeds hand-built statistics to sentenceLevelBleuPlusOne() and compares
// the result with a precomputed BLEU+1 value.
BOOST_AUTO_TEST_CASE(sentence_level_bleu) {
  BOOST_REQUIRE(4 == kBleuNgramOrder);

  // Two entries per n-gram order followed by the reference length; the
  // table has 2 * kBleuNgramOrder + 1 elements for the order required above.
  const float kStats[] = {
    6.0f, 6.0f,  // unigram
    4.0f, 5.0f,  // bigram
    2.0f, 4.0f,  // trigram
    1.0f, 3.0f,  // fourgram
    7.0f         // reference-length
  };
  vector<float> stats(kStats, kStats + sizeof(kStats) / sizeof(kStats[0]));

  BOOST_CHECK(IsAlmostEqual(0.5985f, sentenceLevelBleuPlusOne(stats)));
}

View File

@ -70,22 +70,6 @@ public:
const pair<size_t,size_t>& getTranslation2() const { return m_translation2; }
};
// Sentence-level BLEU with add-one smoothing. For each n-gram order n the
// ratio (stats[2n] + 1) / (stats[2n + 1] + 1) is accumulated in log space and
// averaged over kBleuNgramOrder; a negative brevity term, computed from the
// entry at index kBleuNgramOrder * 2 relative to stats[1], is added before
// exponentiating.
static float sentenceLevelBleuPlusOne(const vector<float>& stats) {
  float logbleu = 0.0;
  for (int order = 0; order < kBleuNgramOrder; ++order) {
    const float numerator = stats[2 * order] + 1;
    const float denominator = stats[2 * order + 1] + 1;
    logbleu += log(numerator) - log(denominator);
  }
  logbleu /= kBleuNgramOrder;

  const float last_entry = stats[(kBleuNgramOrder * 2)];
  const float brevity = 1.0 - last_entry / stats[1];
  if (brevity < 0.0) {
    logbleu += brevity;
  }

  return exp(logbleu);
}
static void outputSample(ostream& out, const FeatureDataItem& f1, const FeatureDataItem& f2) {
// difference in score in regular features
for(unsigned int j=0; j<f1.dense.size(); j++)