Refactor PerScorer and BleuScorer to remove common code

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@1704 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
bhaddow 2008-05-15 14:48:11 +00:00
parent b0ee845d7e
commit f320cf5174
7 changed files with 84 additions and 117 deletions

View File

@ -147,7 +147,7 @@ void BleuScorer::prepareStats(int sid, const string& text, ScoreStats& entry) {
entry.set(stats_str);
}
float BleuScorer::bleu(const vector<int>& comps) {
float BleuScorer::calculateScore(const vector<int>& comps) {
float logbleu = 0.0;
for (int i = 0; i < LENGTH; ++i) {
if (comps[2*i] == 0) {
@ -164,55 +164,3 @@ float BleuScorer::bleu(const vector<int>& comps) {
return exp(logbleu);
}
/**
 * Score a set of candidates with BLEU, then re-score after each diff.
 *
 * candidates[i] selects, for sentence i, which entry of the loaded score data
 * to use. The per-sentence statistics are summed into one vector and passed to
 * bleu(); the result is appended to scores. Each element of diffs is then
 * applied on top of the previous state and one further score is appended per
 * element, so scores grows by 1 + diffs.size() entries.
 *
 * Throws runtime_error if no score data is loaded or if a sentence's
 * statistics do not have LENGTH*2+1 fields.
 **/
void BleuScorer::score(const candidates_t& candidates, const diffs_t& diffs,
scores_t& scores) {
if (!_scoreData) {
throw runtime_error("score data not loaded");
}
//calculate the score for the candidates
//comps holds the running totals of the statistics, summed over all sentences
vector<int> comps(LENGTH*2+1);
for (size_t i = 0; i < candidates.size(); ++i) {
ScoreStats stats = _scoreData->get(i,candidates[i]);
if (stats.size() != comps.size()) {
stringstream msg;
//NOTE(review): nothing is streamed between "(" and ",", so the sentence
//index is missing from this message - looks like i was meant to go there
msg << "Bleu statistics for (" << "," << candidates[i] << ") have incorrect "
<< "number of fields. Found: " << stats.size() << " Expected: "
<< comps.size();
throw runtime_error(msg.str());
}
for (size_t k = 0; k < comps.size(); ++k) {
comps[k] += stats.get(k);
}
}
scores.push_back(bleu(comps));
//remembers the currently selected entry for each sentence, so that each diff
//can subtract the statistics it replaces
candidates_t last_candidates(candidates);
//apply each of the diffs, and get new scores
for (size_t i = 0; i < diffs.size(); ++i) {
for (size_t j = 0; j < diffs[i].size(); ++j) {
//each diff entry is a pair: first is used as the sentence id, second as the
//newly selected entry for that sentence
size_t sid = diffs[i][j].first;
size_t nid = diffs[i][j].second;
size_t last_nid = last_candidates[sid];
//update the totals incrementally: add the new entry's statistics and remove
//those of the entry it replaces
for (size_t k = 0; k < comps.size(); ++k) {
int diff = _scoreData->get(sid,nid).get(k)
- _scoreData->get(sid,last_nid).get(k);
comps[k] += diff;
}
last_candidates[sid] = nid;
}
//one score per diff, computed on the totals accumulated so far
scores.push_back(bleu(comps));
}
}
/*
void BleuScorer::prepare(const vector<string>& referencefiles, const string& nbestfile) {
//processReferences(referencefiles, refcounts,reflengths,encodings);
//processNbest(nbestfile,refcounts,reflengths,encodings);
}*/

View File

@ -25,16 +25,15 @@ enum BleuReferenceLengthStrategy { AVERAGE, SHORTEST, CLOSEST };
/**
* Bleu scoring
**/
class BleuScorer: public Scorer {
class BleuScorer: public StatisticsBasedScorer {
public:
BleuScorer() : Scorer("BLEU"),_refLengthStrategy(SHORTEST) {}
BleuScorer() : StatisticsBasedScorer("BLEU"),_refLengthStrategy(SHORTEST) {}
virtual void setReferenceFiles(const vector<string>& referenceFiles);
virtual void prepareStats(int sid, const string& text, ScoreStats& entry);
virtual void score(const candidates_t& candidates, const diffs_t& diffs,
scores_t& scores);
static const int LENGTH;
protected:
float calculateScore(const vector<int>& comps);
private:
//no copy
@ -81,7 +80,6 @@ class BleuScorer: public Scorer {
typedef vector<counts_t*> refcounts_t;
size_t countNgrams(const string& line, counts_t& counts, unsigned int n);
float bleu(const vector<int>& comps);
void dump_counts(counts_t& counts) {
for (counts_it i = counts.begin(); i != counts.end(); ++i) {

View File

@ -5,7 +5,8 @@ Data.o \
BleuScorer.o \
Point.o \
Optimizer.o \
PerScorer.o
PerScorer.o \
Scorer.o
ifndef DEBUG
CFLAGS=-O3 -DTRACE_ENABLE

View File

@ -57,7 +57,7 @@ void PerScorer::prepareStats(int sid, const string& text, ScoreStats& entry) {
entry.set(stats_str);
}
float PerScorer::per(const vector<int>& comps) const {
float PerScorer::calculateScore(const vector<int>& comps) {
float denom = comps[2];
float num = comps[0] - max(0,comps[1]-comps[2]);
if (denom == 0) {
@ -67,52 +67,3 @@ float PerScorer::per(const vector<int>& comps) const {
return num/denom;
}
}
/**
 * Score a set of candidates with PER, then re-score after each diff.
 *
 * candidates[i] selects, for sentence i, which entry of the loaded score data
 * to use. The per-sentence statistics are summed into one vector and passed to
 * per(); the result is appended to scores. Each element of diffs is then
 * applied on top of the previous state and one further score is appended per
 * element, so scores grows by 1 + diffs.size() entries.
 *
 * Throws runtime_error if no score data is loaded or if a sentence's
 * statistics do not have exactly 3 fields.
 **/
void PerScorer::score(const candidates_t& candidates, const diffs_t& diffs,
scores_t& scores) {
//calculate the PER
/* Implementation of position-independent word error rate. This is defined
* as 1 - (correct - max(0,output_length - ref_length)) / ref_length
* In fact, we ignore the " 1 - " so that it can be maximised.
*/
//TODO: This code is pretty much the same as bleu. Could factor it out into
//a common superclass
if (!_scoreData) {
throw runtime_error("score data not loaded");
}
//calculate the score for the candidates
//comps holds the running totals of the statistics, summed over all sentences
vector<int> comps(3); //correct, output, ref
for (size_t i = 0; i < candidates.size(); ++i) {
ScoreStats stats = _scoreData->get(i,candidates[i]);
if (stats.size() != comps.size()) {
stringstream msg;
//NOTE(review): nothing is streamed between "(" and ",", so the sentence
//index is missing from this message - looks like i was meant to go there
msg << "PER statistics for (" << "," << candidates[i] << ") have incorrect "
<< "number of fields. Found: " << stats.size() << " Expected: "
<< comps.size();
throw runtime_error(msg.str());
}
for (size_t k = 0; k < comps.size(); ++k) {
comps[k] += stats.get(k);
}
}
scores.push_back(per(comps));
//remembers the currently selected entry for each sentence, so that each diff
//can subtract the statistics it replaces
candidates_t last_candidates(candidates);
//apply each of the diffs, and get new scores
for (size_t i = 0; i < diffs.size(); ++i) {
for (size_t j = 0; j < diffs[i].size(); ++j) {
//each diff entry is a pair: first is used as the sentence id, second as the
//newly selected entry for that sentence
size_t sid = diffs[i][j].first;
size_t nid = diffs[i][j].second;
size_t last_nid = last_candidates[sid];
//update the totals incrementally: add the new entry's statistics and remove
//those of the entry it replaces
for (size_t k = 0; k < comps.size(); ++k) {
int diff = _scoreData->get(sid,nid).get(k)
- _scoreData->get(sid,last_nid).get(k);
comps[k] += diff;
}
last_candidates[sid] = nid;
}
//one score per diff, computed on the totals accumulated so far
scores.push_back(per(comps));
}
}

View File

@ -23,17 +23,17 @@ using namespace std;
* as 1 - (correct - max(0,output_length - ref_length)) / ref_length
* In fact, we ignore the " 1 - " so that it can be maximised.
**/
class PerScorer: public Scorer {
class PerScorer: public StatisticsBasedScorer {
public:
PerScorer() : Scorer("PER") {}
PerScorer() : StatisticsBasedScorer("PER") {}
virtual void setReferenceFiles(const vector<string>& referenceFiles);
virtual void prepareStats(int sid, const string& text, ScoreStats& entry);
virtual void score(const candidates_t& candidates, const diffs_t& diffs,
scores_t& scores);
protected:
virtual float calculateScore(const vector<int>& comps) ;
private:
float per(const vector<int>& comps) const;
//no copy
PerScorer(const PerScorer&);

52
mert/Scorer.cpp Normal file
View File

@ -0,0 +1,52 @@
#include "Scorer.h"
/**
 * Score a set of candidates, then re-score after applying each diff.
 *
 * candidates[i] selects, for sentence i, which entry of the loaded score data
 * to use. The statistics of the selected entries are summed field by field
 * and handed to calculateScore(); the result is appended to scores. Each
 * element of diffs is then applied on top of the previous state (diffs are
 * cumulative) and one further score is appended per element, so scores grows
 * by 1 + diffs.size() entries.
 *
 * Throws runtime_error if the score data is missing or empty, if candidates
 * is empty, if a sentence's statistics have a different number of fields from
 * the first sentence's, or if a diff names a sentence for which no candidate
 * was supplied.
 **/
void StatisticsBasedScorer::score(const candidates_t& candidates, const diffs_t& diffs,
                                  scores_t& scores) {
    if (!_scoreData) {
        throw runtime_error("Score data not loaded");
    }
    if (_scoreData->size() == 0) {
        throw runtime_error("Score data is empty");
    }
    if (candidates.size() == 0) {
        throw runtime_error("No candidates supplied");
    }

    //calculate the score for the candidates
    //the first sentence's statistics fix the number of fields that every
    //other sentence is required to have
    const size_t numCounts = _scoreData->get(0,candidates[0]).size();
    vector<int> totals(numCounts);
    for (size_t i = 0; i < candidates.size(); ++i) {
        ScoreStats stats = _scoreData->get(i,candidates[i]);
        if (stats.size() != totals.size()) {
            stringstream msg;
            //report both the sentence index and the selected entry
            msg << "Statistics for (" << i << "," << candidates[i] << ") have incorrect "
                << "number of fields. Found: " << stats.size() << " Expected: "
                << totals.size();
            throw runtime_error(msg.str());
        }
        for (size_t k = 0; k < totals.size(); ++k) {
            totals[k] += stats.get(k);
        }
    }
    scores.push_back(calculateScore(totals));

    //remembers the currently selected entry for each sentence, so that each
    //diff can subtract the statistics it replaces
    candidates_t last_candidates(candidates);
    //apply each of the diffs, and get new scores
    for (size_t i = 0; i < diffs.size(); ++i) {
        for (size_t j = 0; j < diffs[i].size(); ++j) {
            size_t sid = diffs[i][j].first;
            size_t nid = diffs[i][j].second;
            //a sentence id beyond the supplied candidates would index
            //last_candidates out of range, so reject it explicitly
            if (sid >= last_candidates.size()) {
                stringstream msg;
                msg << "Diff refers to sentence " << sid << " but only "
                    << last_candidates.size() << " candidates were supplied";
                throw runtime_error(msg.str());
            }
            size_t last_nid = last_candidates[sid];
            //update the totals incrementally: add the new entry's statistics
            //and remove those of the entry it replaces
            for (size_t k = 0; k < totals.size(); ++k) {
                int diff = _scoreData->get(sid,nid).get(k)
                           - _scoreData->get(sid,last_nid).get(k);
                totals[k] += diff;
            }
            last_candidates[sid] = nid;
        }
        //one score per diff, computed on the totals accumulated so far
        scores.push_back(calculateScore(totals));
    }
}

View File

@ -21,7 +21,7 @@ class ScoreStats;
/**
* Superclass of all scorers and dummy implementation. In order to add a new
* scorer it should be sufficient to override prepareStats(), setReferenceFiles()
* and score()
* and score() (or calculateScore()).
**/
class Scorer {
@ -135,4 +135,21 @@ class Scorer {
};
/**
 * Abstract base class for scorers that work by adding statistics across all
 * output sentences, then applying some formula to the totals, e.g. bleu, per.
 *
 * Subclasses implement calculateScore(); score() is declared here so that
 * the accumulation of statistics can be shared between such scorers.
 **/
class StatisticsBasedScorer : public Scorer {
public:
    //explicit: a scorer should never be created implicitly from a name string
    explicit StatisticsBasedScorer(const string& name): Scorer(name) {}
    //appends scores for the given candidates and diffs to scores
    virtual void score(const candidates_t& candidates, const diffs_t& diffs,
                       scores_t& scores);
protected:
    //calculate the actual score from the summed statistics
    virtual float calculateScore(const vector<int>& totals) = 0;
};
#endif //__SCORER_H