2008-05-27 20:50:52 +04:00
|
|
|
#include "BleuScorer.h"
|
2011-11-14 14:52:21 +04:00
|
|
|
|
|
|
|
#include <algorithm>
|
2012-03-19 17:45:15 +04:00
|
|
|
#include <cassert>
|
2011-11-14 10:15:30 +04:00
|
|
|
#include <cmath>
|
2011-11-14 14:52:21 +04:00
|
|
|
#include <climits>
|
2011-11-14 10:15:30 +04:00
|
|
|
#include <fstream>
|
2012-02-25 21:01:03 +04:00
|
|
|
#include <iostream>
|
2011-11-14 10:15:30 +04:00
|
|
|
#include <stdexcept>
|
2012-04-06 20:02:32 +04:00
|
|
|
|
|
|
|
#include "util/check.hh"
|
2012-03-14 17:14:11 +04:00
|
|
|
#include "Ngram.h"
|
2012-03-18 00:58:40 +04:00
|
|
|
#include "Reference.h"
|
2011-11-14 10:15:30 +04:00
|
|
|
#include "Util.h"
|
2012-04-29 10:11:30 +04:00
|
|
|
#include "ScoreDataIterator.h"
|
|
|
|
#include "FeatureDataIterator.h"
|
2012-03-20 00:49:10 +04:00
|
|
|
#include "Vocabulary.h"
|
2008-05-14 16:23:58 +04:00
|
|
|
|
2012-05-10 02:51:05 +04:00
|
|
|
using namespace std;
|
|
|
|
|
2012-02-25 20:18:08 +04:00
|
|
|
namespace {

// Name of the configuration option read by the BleuScorer constructor to
// choose the reference-length strategy.
const char KEY_REFLEN[] = "reflen";

// Values of that option which the BleuScorer constructor recognises.
const char REFLEN_CLOSEST[] = "closest";
const char REFLEN_SHORTEST[] = "shortest";
const char REFLEN_AVERAGE[] = "average";

}  // namespace
|
2008-05-14 16:23:58 +04:00
|
|
|
|
2012-06-30 23:23:45 +04:00
|
|
|
namespace MosesTuning
|
|
|
|
{
|
|
|
|
|
|
|
|
|
2011-11-12 05:16:31 +04:00
|
|
|
// Constructs a BLEU scorer from a configuration string.
//
// The "reflen" option selects the reference-length strategy and must be one
// of "average", "shortest" or "closest". REFLEN_CLOSEST is handed to
// getConfig as its second argument (presumably the value used when the option
// is not set -- confirm against getConfig).
//
// Throws runtime_error when the option holds any other value.
BleuScorer::BleuScorer(const string& config)
    : StatisticsBasedScorer("BLEU", config),
      m_ref_length_type(CLOSEST)
{
  const string strategy = getConfig(KEY_REFLEN, REFLEN_CLOSEST);

  if (strategy == REFLEN_CLOSEST) {
    m_ref_length_type = CLOSEST;
    return;
  }
  if (strategy == REFLEN_SHORTEST) {
    m_ref_length_type = SHORTEST;
    return;
  }
  if (strategy == REFLEN_AVERAGE) {
    m_ref_length_type = AVERAGE;
    return;
  }

  throw runtime_error("Unknown reference length strategy: " + strategy);
}
|
|
|
|
|
|
|
|
// Nothing to release explicitly here; members clean up through their own
// destructors.
BleuScorer::~BleuScorer()
{
}
|
2008-05-14 17:36:55 +04:00
|
|
|
|
2012-03-19 17:45:15 +04:00
|
|
|
size_t BleuScorer::CountNgrams(const string& line, NgramCounts& counts,
|
2012-02-25 20:11:56 +04:00
|
|
|
unsigned int n)
|
2011-02-24 15:42:19 +03:00
|
|
|
{
|
2012-03-19 17:45:15 +04:00
|
|
|
assert(n > 0);
|
2011-02-24 15:42:19 +03:00
|
|
|
vector<int> encoded_tokens;
|
2012-02-01 16:19:25 +04:00
|
|
|
TokenizeAndEncode(line, encoded_tokens);
|
2011-02-24 15:42:19 +03:00
|
|
|
for (size_t k = 1; k <= n; ++k) {
|
|
|
|
//ngram order longer than sentence - no point
|
|
|
|
if (k > encoded_tokens.size()) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
for (size_t i = 0; i < encoded_tokens.size()-k+1; ++i) {
|
|
|
|
vector<int> ngram;
|
|
|
|
for (size_t j = i; j < i+k && j < encoded_tokens.size(); ++j) {
|
|
|
|
ngram.push_back(encoded_tokens[j]);
|
|
|
|
}
|
2012-03-19 18:21:02 +04:00
|
|
|
counts.Add(ngram);
|
2011-02-24 15:42:19 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return encoded_tokens.size();
|
2008-05-14 17:36:55 +04:00
|
|
|
}
|
|
|
|
|
2011-02-24 15:42:19 +03:00
|
|
|
void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
|
|
|
|
{
|
2012-03-18 00:58:40 +04:00
|
|
|
// Make sure reference data is clear
|
|
|
|
m_references.reset();
|
2012-03-20 00:49:10 +04:00
|
|
|
mert::VocabularyFactory::GetVocabulary()->clear();
|
2008-05-14 17:36:55 +04:00
|
|
|
|
2011-02-24 15:42:19 +03:00
|
|
|
//load reference data
|
|
|
|
for (size_t i = 0; i < referenceFiles.size(); ++i) {
|
|
|
|
TRACE_ERR("Loading reference from " << referenceFiles[i] << endl);
|
2012-04-04 17:33:30 +04:00
|
|
|
|
|
|
|
if (!OpenReference(referenceFiles[i].c_str(), i)) {
|
|
|
|
throw runtime_error("Unable to open " + referenceFiles[i]);
|
2011-02-24 15:42:19 +03:00
|
|
|
}
|
2012-04-04 17:33:30 +04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
bool BleuScorer::OpenReference(const char* filename, size_t file_id) {
|
|
|
|
ifstream ifs(filename);
|
|
|
|
if (!ifs) {
|
|
|
|
cerr << "Cannot open " << filename << endl;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return OpenReferenceStream(&ifs, file_id);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Reads one reference sentence per line from `is` and merges it into
// m_references.
//
//   is      - input stream; a NULL pointer makes the call return false
//   file_id - index of the reference file being read. For file_id == 0 a new
//             Reference is appended for every line; for any other id the
//             line is merged into the entry with the same line index.
//
// For every n-gram of the line, the stored count is raised to this line's
// count when the latter is larger, and the line's token count is appended to
// the Reference.
//
// Returns false if `is` is NULL or the stream contains more lines than there
// are stored references; true otherwise.
// NOTE(review): a stream with fewer lines than the stored references is not
// reported here.
bool BleuScorer::OpenReferenceStream(istream* is, size_t file_id)
{
  if (is == NULL) return false;

  string line;
  size_t sid = 0;
  while (getline(*is, line)) {
    line = preprocessSentence(line);
    if (file_id == 0) {
      Reference* ref = new Reference;
      m_references.push_back(ref);  // Take ownership of the Reference object.
    }
    if (m_references.size() <= sid) {
      // Fixed: the message previously lacked the space after the file id.
      cerr << "Reference " << file_id << " has too many sentences." << endl;
      return false;
    }

    NgramCounts counts;
    size_t length = CountNgrams(line, counts, kBleuNgramOrder);

    // For any counts larger than those already there, merge them in.
    for (NgramCounts::const_iterator ci = counts.begin(); ci != counts.end(); ++ci) {
      const NgramCounts::Key& ngram = ci->first;
      const NgramCounts::Value newcount = ci->second;

      // oldcount keeps its initial 0 unless Lookup writes a stored value.
      NgramCounts::Value oldcount = 0;
      m_references[sid]->get_counts()->Lookup(ngram, &oldcount);
      if (newcount > oldcount) {
        m_references[sid]->get_counts()->operator[](ngram) = newcount;
      }
    }

    // Add in the length.
    m_references[sid]->push_back(length);

    // Progress marker every 100 sentences.
    if (sid > 0 && sid % 100 == 0) {
      TRACE_ERR(".");
    }
    ++sid;
  }
  return true;
}
|
|
|
|
|
2011-02-24 15:42:19 +03:00
|
|
|
void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
|
|
|
|
{
|
2012-03-18 00:58:40 +04:00
|
|
|
if (sid >= m_references.size()) {
|
2011-02-24 15:42:19 +03:00
|
|
|
stringstream msg;
|
|
|
|
msg << "Sentence id (" << sid << ") not found in reference set";
|
|
|
|
throw runtime_error(msg.str());
|
|
|
|
}
|
2012-02-25 20:11:56 +04:00
|
|
|
NgramCounts testcounts;
|
2012-02-25 20:41:17 +04:00
|
|
|
// stats for this line
|
2012-03-09 21:49:31 +04:00
|
|
|
vector<ScoreStatsType> stats(kBleuNgramOrder * 2);
|
2012-05-09 21:21:41 +04:00
|
|
|
string sentence = preprocessSentence(text);
|
2012-03-19 17:45:15 +04:00
|
|
|
const size_t length = CountNgrams(sentence, testcounts, kBleuNgramOrder);
|
2012-02-25 13:14:00 +04:00
|
|
|
|
2012-03-18 00:58:40 +04:00
|
|
|
const int reference_len = CalcReferenceLength(sid, length);
|
|
|
|
stats.push_back(reference_len);
|
2012-02-25 20:54:51 +04:00
|
|
|
|
2011-02-24 15:42:19 +03:00
|
|
|
//precision on each ngram type
|
2012-02-25 20:11:56 +04:00
|
|
|
for (NgramCounts::const_iterator testcounts_it = testcounts.begin();
|
2011-02-24 15:42:19 +03:00
|
|
|
testcounts_it != testcounts.end(); ++testcounts_it) {
|
2012-03-14 17:44:51 +04:00
|
|
|
const NgramCounts::Value guess = testcounts_it->second;
|
|
|
|
const size_t len = testcounts_it->first.size();
|
|
|
|
NgramCounts::Value correct = 0;
|
2012-03-14 17:41:29 +04:00
|
|
|
|
|
|
|
NgramCounts::Value v = 0;
|
2012-03-19 18:21:02 +04:00
|
|
|
if (m_references[sid]->get_counts()->Lookup(testcounts_it->first, &v)) {
|
2012-03-14 17:41:29 +04:00
|
|
|
correct = min(v, guess);
|
2011-02-24 15:42:19 +03:00
|
|
|
}
|
2012-03-14 17:44:51 +04:00
|
|
|
stats[len * 2 - 2] += correct;
|
|
|
|
stats[len * 2 - 1] += guess;
|
2011-02-24 15:42:19 +03:00
|
|
|
}
|
2012-02-25 20:41:17 +04:00
|
|
|
entry.set(stats);
|
2008-05-14 17:36:55 +04:00
|
|
|
}
|
|
|
|
|
2012-06-24 06:51:48 +04:00
|
|
|
// Computes BLEU from accumulated statistics.
//
// `comps` must hold kBleuNgramOrder * 2 + 1 values (checked): per order a
// (matched, guessed) pair, then the reference length in the last slot.
// Returns 0 as soon as some order has no matches. Otherwise returns
// exp(mean log precision), with the brevity term added to the exponent when
// it is negative.
statscore_t BleuScorer::calculateScore(const vector<int>& comps) const
{
  CHECK(comps.size() == kBleuNgramOrder * 2 + 1);

  float log_score = 0.0;
  int order = 0;
  while (order < kBleuNgramOrder) {
    const int matched = comps[2 * order];
    if (matched == 0) {
      return 0.0;
    }
    const int guessed = comps[2 * order + 1];
    log_score += log(matched) - log(guessed);
    ++order;
  }
  log_score /= kBleuNgramOrder;

  // 1 - (reference length / test length); comps[1] is the unigram guess count.
  const float ref_length = static_cast<float>(comps[kBleuNgramOrder * 2]);
  const float brevity = 1.0 - ref_length / comps[1];
  if (brevity < 0.0) {
    log_score += brevity;
  }

  return exp(log_score);
}
|
2011-11-12 05:21:08 +04:00
|
|
|
|
2012-03-18 00:58:40 +04:00
|
|
|
int BleuScorer::CalcReferenceLength(size_t sentence_id, size_t length) {
|
|
|
|
switch (m_ref_length_type) {
|
|
|
|
case AVERAGE:
|
|
|
|
return m_references[sentence_id]->CalcAverage();
|
|
|
|
break;
|
|
|
|
case CLOSEST:
|
|
|
|
return m_references[sentence_id]->CalcClosest(length);
|
|
|
|
break;
|
|
|
|
case SHORTEST:
|
|
|
|
return m_references[sentence_id]->CalcShortest();
|
|
|
|
break;
|
|
|
|
default:
|
|
|
|
cerr << "unknown reference types." << endl;
|
|
|
|
exit(1);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2012-03-19 17:45:15 +04:00
|
|
|
// Writes all entries of `counts` to `*os` on one line, each formatted as
// "(k1 k2 ...) : count, ", and ends the line with endl.
void BleuScorer::DumpCounts(ostream* os,
                            const NgramCounts& counts) const
{
  ostream& out = *os;

  NgramCounts::const_iterator entry = counts.begin();
  while (entry != counts.end()) {
    const NgramCounts::Key& keys = entry->first;

    out << "(";
    // Space-separated key ids: first one bare, the rest preceded by a space.
    if (keys.size() > 0) {
      out << keys[0];
      for (size_t pos = 1; pos < keys.size(); ++pos) {
        out << " ";
        out << keys[pos];
      }
    }
    out << ") : " << entry->second << ", ";

    ++entry;
  }

  out << endl;
}
|
2012-02-25 20:54:51 +04:00
|
|
|
|
2012-04-06 20:02:32 +04:00
|
|
|
// Sentence-level BLEU with add-one smoothing: 1 is added to both the matched
// and the guessed count of every n-gram order before taking logs.
//
// `stats` must hold kBleuNgramOrder * 2 + 1 values (checked): per order a
// (matched, guessed) pair, then the reference length in the last slot.
float sentenceLevelBleuPlusOne(const vector<float>& stats)
{
  CHECK(stats.size() == kBleuNgramOrder * 2 + 1);

  float log_score = 0.0;
  int order = 0;
  while (order < kBleuNgramOrder) {
    const double matched = stats[2 * order] + 1.0;
    const double guessed = stats[2 * order + 1] + 1.0;
    log_score += log(matched) - log(guessed);
    ++order;
  }
  log_score /= kBleuNgramOrder;

  // 1 - (reference length / hypothesis length); only a negative value is
  // added to the exponent.
  const float length_ratio = stats[(kBleuNgramOrder * 2)] / stats[1];
  const float brevity = 1.0 - length_ratio;
  if (brevity < 0.0) {
    log_score += brevity;
  }

  return exp(log_score);
}
|
2011-11-12 05:21:08 +04:00
|
|
|
|
2012-05-29 21:38:57 +04:00
|
|
|
float sentenceLevelBackgroundBleu(const std::vector<float>& sent, const std::vector<float>& bg)
|
|
|
|
{
|
|
|
|
// Sum sent and background
|
|
|
|
std::vector<float> stats;
|
|
|
|
CHECK(sent.size()==bg.size());
|
|
|
|
CHECK(sent.size()==kBleuNgramOrder*2+1);
|
|
|
|
for(size_t i=0;i<sent.size();i++)
|
|
|
|
stats.push_back(sent[i]+bg[i]);
|
|
|
|
|
|
|
|
// Calculate BLEU
|
|
|
|
float logbleu = 0.0;
|
|
|
|
for (int j = 0; j < kBleuNgramOrder; j++) {
|
|
|
|
logbleu += log(stats[2 * j]) - log(stats[2 * j + 1]);
|
2011-11-12 05:21:08 +04:00
|
|
|
}
|
2012-05-29 21:38:57 +04:00
|
|
|
logbleu /= kBleuNgramOrder;
|
|
|
|
const float brevity = 1.0 - stats[(kBleuNgramOrder * 2)] / stats[1];
|
|
|
|
|
|
|
|
if (brevity < 0.0) {
|
|
|
|
logbleu += brevity;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Exponentiate and scale by reference length (as per Chiang et al 08)
|
|
|
|
return exp(logbleu) * stats[kBleuNgramOrder*2];
|
|
|
|
}
|
|
|
|
|
|
|
|
// BLEU computed directly from `stats` without any smoothing.
//
// `stats` must hold kBleuNgramOrder * 2 + 1 values (checked): per order a
// (matched, guessed) pair, then the reference length in the last slot.
float unsmoothedBleu(const std::vector<float>& stats)
{
  CHECK(stats.size() == kBleuNgramOrder * 2 + 1);

  float log_score = 0.0;
  int order = 0;
  while (order < kBleuNgramOrder) {
    const float matched = stats[2 * order];
    const float guessed = stats[2 * order + 1];
    log_score += log(matched) - log(guessed);
    ++order;
  }
  log_score /= kBleuNgramOrder;

  // 1 - (reference length / hypothesis length); only a negative value is
  // added to the exponent.
  const float ref_length = stats[(kBleuNgramOrder * 2)];
  const float brevity = 1.0 - ref_length / stats[1];
  if (brevity < 0.0) {
    log_score += brevity;
  }

  return exp(log_score);
}
|
2012-04-29 10:11:30 +04:00
|
|
|
|
|
|
|
// Scores every entry at the current position of the given score file with
// sentenceLevelBleuPlusOne and returns the scores in order.
//
//   scoreFile   - path handed to ScoreDataIterator
//   featureFile - path handed to FeatureDataIterator; used here only to
//                 validate that it has as many entries as the score data
//
// Any inconsistency (an iterator already at its end, or differing entry
// counts) is reported on cerr and terminates the process with exit(1).
vector<float> BleuScorer::ScoreNbestList(string scoreFile, string featureFile)
{
  vector<string> score_paths;
  vector<string> feature_paths;
  score_paths.push_back(scoreFile);
  feature_paths.push_back(featureFile);

  // One iterator pair per file.
  vector<FeatureDataIterator> feature_iters;
  vector<ScoreDataIterator> score_iters;
  for (size_t f = 0; f < feature_paths.size(); ++f) {
    feature_iters.push_back(FeatureDataIterator(feature_paths[f]));
    score_iters.push_back(ScoreDataIterator(score_paths[f]));
  }

  // (file index, entry index) of every hypothesis to score.
  vector<pair<size_t,size_t> > candidates;
  if (feature_iters[0] == FeatureDataIterator::end()) {
    cerr << "Error: at the end of feature data iterator" << endl;
    exit(1);
  }
  for (size_t f = 0; f < feature_paths.size(); ++f) {
    if (feature_iters[f] == FeatureDataIterator::end()) {
      cerr << "Error: Feature file " << f << " ended prematurely" << endl;
      exit(1);
    }
    if (score_iters[f] == ScoreDataIterator::end()) {
      cerr << "Error: Score file " << f << " ended prematurely" << endl;
      exit(1);
    }
    if (feature_iters[f]->size() != score_iters[f]->size()) {
      cerr << "Error: features and scores have different size" << endl;
      exit(1);
    }
    for (size_t h = 0; h < feature_iters[f]->size(); ++h) {
      candidates.push_back(pair<size_t,size_t>(f, h));
    }
  }

  // Score the n-best list.
  vector<float> scores;
  for (size_t c = 0; c < candidates.size(); ++c) {
    const size_t file_index = candidates[c].first;
    const size_t entry_index = candidates[c].second;
    scores.push_back(
        sentenceLevelBleuPlusOne(score_iters[file_index]->operator[](entry_index)));
  }
  return scores;
}
|
|
|
|
|
|
|
|
float BleuScorer::sentenceLevelBleuPlusOne(const vector<float>& stats) {
|
|
|
|
float logbleu = 0.0;
|
|
|
|
const unsigned int bleu_order = 4;
|
|
|
|
for (unsigned int j=0; j<bleu_order; j++) {
|
|
|
|
//cerr << (stats.get(2*j)+1) << "/" << (stats.get(2*j+1)+1) << " ";
|
|
|
|
logbleu += log(stats[2*j]+1) - log(stats[2*j+1]+1);
|
|
|
|
}
|
|
|
|
logbleu /= bleu_order;
|
|
|
|
float brevity = 1.0 - (float)stats[(bleu_order*2)]/stats[1];
|
|
|
|
if (brevity < 0.0) {
|
|
|
|
logbleu += brevity;
|
|
|
|
}
|
|
|
|
//cerr << brevity << " -> " << exp(logbleu) << endl;
|
|
|
|
return exp(logbleu);
|
|
|
|
}
|
2012-06-30 23:23:45 +04:00
|
|
|
|
2012-07-17 16:36:50 +04:00
|
|
|
}
|