From f320cf51749f26658533e7d5ed33fff41f05a0e0 Mon Sep 17 00:00:00 2001 From: bhaddow Date: Thu, 15 May 2008 14:48:11 +0000 Subject: [PATCH] Refactor PerScorer and BleuScorer to remove common code git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@1704 1f5c12ca-751b-0410-a591-d2e778427230 --- mert/BleuScorer.cpp | 54 +-------------------------------------------- mert/BleuScorer.h | 12 +++++----- mert/Makefile | 3 ++- mert/PerScorer.cpp | 51 +----------------------------------------- mert/PerScorer.h | 10 ++++----- mert/Scorer.cpp | 52 +++++++++++++++++++++++++++++++++++++++++++ mert/Scorer.h | 19 +++++++++++++++- 7 files changed, 84 insertions(+), 117 deletions(-) create mode 100644 mert/Scorer.cpp diff --git a/mert/BleuScorer.cpp b/mert/BleuScorer.cpp index 3abb05c0f..f3cbdd320 100644 --- a/mert/BleuScorer.cpp +++ b/mert/BleuScorer.cpp @@ -147,7 +147,7 @@ void BleuScorer::prepareStats(int sid, const string& text, ScoreStats& entry) { entry.set(stats_str); } -float BleuScorer::bleu(const vector& comps) { +float BleuScorer::calculateScore(const vector& comps) { float logbleu = 0.0; for (int i = 0; i < LENGTH; ++i) { if (comps[2*i] == 0) { @@ -164,55 +164,3 @@ float BleuScorer::bleu(const vector& comps) { return exp(logbleu); } - - -void BleuScorer::score(const candidates_t& candidates, const diffs_t& diffs, - scores_t& scores) { - if (!_scoreData) { - throw runtime_error("score data not loaded"); - } - //calculate the score for the candidates - vector comps(LENGTH*2+1); - for (size_t i = 0; i < candidates.size(); ++i) { - ScoreStats stats = _scoreData->get(i,candidates[i]); - if (stats.size() != comps.size()) { - stringstream msg; - msg << "Bleu statistics for (" << "," << candidates[i] << ") have incorrect " - << "number of fields. 
Found: " << stats.size() << " Expected: " - << comps.size(); - throw runtime_error(msg.str()); - } - for (size_t k = 0; k < comps.size(); ++k) { - comps[k] += stats.get(k); - } - } - scores.push_back(bleu(comps)); - - candidates_t last_candidates(candidates); - //apply each of the diffs, and get new scores - for (size_t i = 0; i < diffs.size(); ++i) { - for (size_t j = 0; j < diffs[i].size(); ++j) { - size_t sid = diffs[i][j].first; - size_t nid = diffs[i][j].second; - size_t last_nid = last_candidates[sid]; - for (size_t k = 0; k < comps.size(); ++k) { - int diff = _scoreData->get(sid,nid).get(k) - - _scoreData->get(sid,last_nid).get(k); - comps[k] += diff; - } - last_candidates[sid] = nid; - } - scores.push_back(bleu(comps)); - } - -} - - -/* -void BleuScorer::prepare(const vector& referencefiles, const string& nbestfile) { - //processReferences(referencefiles, refcounts,reflengths,encodings); - //processNbest(nbestfile,refcounts,reflengths,encodings); -}*/ - - - diff --git a/mert/BleuScorer.h b/mert/BleuScorer.h index ade7559e5..ce541345a 100644 --- a/mert/BleuScorer.h +++ b/mert/BleuScorer.h @@ -25,16 +25,15 @@ enum BleuReferenceLengthStrategy { AVERAGE, SHORTEST, CLOSEST }; /** * Bleu scoring **/ -class BleuScorer: public Scorer { +class BleuScorer: public StatisticsBasedScorer { public: - BleuScorer() : Scorer("BLEU"),_refLengthStrategy(SHORTEST) {} + BleuScorer() : StatisticsBasedScorer("BLEU"),_refLengthStrategy(SHORTEST) {} virtual void setReferenceFiles(const vector& referenceFiles); virtual void prepareStats(int sid, const string& text, ScoreStats& entry); - - virtual void score(const candidates_t& candidates, const diffs_t& diffs, - scores_t& scores); - static const int LENGTH; + + protected: + float calculateScore(const vector& comps); private: //no copy @@ -81,7 +80,6 @@ class BleuScorer: public Scorer { typedef vector refcounts_t; size_t countNgrams(const string& line, counts_t& counts, unsigned int n); - float bleu(const vector& comps); void 
dump_counts(counts_t& counts) { for (counts_it i = counts.begin(); i != counts.end(); ++i) { diff --git a/mert/Makefile b/mert/Makefile index e364fdd0f..eb52da940 100755 --- a/mert/Makefile +++ b/mert/Makefile @@ -5,7 +5,8 @@ Data.o \ BleuScorer.o \ Point.o \ Optimizer.o \ -PerScorer.o +PerScorer.o \ +Scorer.o ifndef DEBUG CFLAGS=-O3 -DTRACE_ENABLE diff --git a/mert/PerScorer.cpp b/mert/PerScorer.cpp index 720f4de98..be5e29aef 100644 --- a/mert/PerScorer.cpp +++ b/mert/PerScorer.cpp @@ -57,7 +57,7 @@ void PerScorer::prepareStats(int sid, const string& text, ScoreStats& entry) { entry.set(stats_str); } -float PerScorer::per(const vector& comps) const { +float PerScorer::calculateScore(const vector& comps) { float denom = comps[2]; float num = comps[0] - max(0,comps[1]-comps[2]); if (denom == 0) { @@ -67,52 +67,3 @@ float PerScorer::per(const vector& comps) const { return num/denom; } } - -void PerScorer::score(const candidates_t& candidates, const diffs_t& diffs, - scores_t& scores) { - //calculate the PER - /* Implementation of position-independent word error rate. This is defined - * as 1 - (correct - max(0,output_length - ref_length)) / ref_length - * In fact, we ignore the " 1 - " so that it can be maximised. - */ - //TODO: This code is pretty much the same as bleu. Could factor it out into - //a common superclass - if (!_scoreData) { - throw runtime_error("score data not loaded"); - } - //calculate the score for the candidates - vector comps(3); //correct, output, ref - for (size_t i = 0; i < candidates.size(); ++i) { - ScoreStats stats = _scoreData->get(i,candidates[i]); - if (stats.size() != comps.size()) { - stringstream msg; - msg << "PER statistics for (" << "," << candidates[i] << ") have incorrect " - << "number of fields. 
Found: " << stats.size() << " Expected: " - << comps.size(); - throw runtime_error(msg.str()); - } - for (size_t k = 0; k < comps.size(); ++k) { - comps[k] += stats.get(k); - } - } - scores.push_back(per(comps)); - - candidates_t last_candidates(candidates); - //apply each of the diffs, and get new scores - for (size_t i = 0; i < diffs.size(); ++i) { - for (size_t j = 0; j < diffs[i].size(); ++j) { - size_t sid = diffs[i][j].first; - size_t nid = diffs[i][j].second; - size_t last_nid = last_candidates[sid]; - for (size_t k = 0; k < comps.size(); ++k) { - int diff = _scoreData->get(sid,nid).get(k) - - _scoreData->get(sid,last_nid).get(k); - comps[k] += diff; - } - last_candidates[sid] = nid; - } - scores.push_back(per(comps)); - } -} - - diff --git a/mert/PerScorer.h b/mert/PerScorer.h index 648b0d90c..868326638 100644 --- a/mert/PerScorer.h +++ b/mert/PerScorer.h @@ -23,17 +23,17 @@ using namespace std; * as 1 - (correct - max(0,output_length - ref_length)) / ref_length * In fact, we ignore the " 1 - " so that it can be maximised. 
**/ -class PerScorer: public Scorer { +class PerScorer: public StatisticsBasedScorer { public: - PerScorer() : Scorer("PER") {} + PerScorer() : StatisticsBasedScorer("PER") {} virtual void setReferenceFiles(const vector& referenceFiles); virtual void prepareStats(int sid, const string& text, ScoreStats& entry); - virtual void score(const candidates_t& candidates, const diffs_t& diffs, - scores_t& scores); + protected: + + virtual float calculateScore(const vector& comps) ; private: - float per(const vector& comps) const; //no copy PerScorer(const PerScorer&); diff --git a/mert/Scorer.cpp b/mert/Scorer.cpp new file mode 100644 index 000000000..329a401d8 --- /dev/null +++ b/mert/Scorer.cpp @@ -0,0 +1,52 @@ +#include "Scorer.h" + + + +void StatisticsBasedScorer::score(const candidates_t& candidates, const diffs_t& diffs, + scores_t& scores) { + if (!_scoreData) { + throw runtime_error("Score data not loaded"); + } + //calculate the score for the candidates + if (_scoreData->size() == 0) { + throw runtime_error("Score data is empty"); + } + if (candidates.size() == 0) { + throw runtime_error("No candidates supplied"); + } + int numCounts = _scoreData->get(0,candidates[0]).size(); + vector totals(numCounts); + for (size_t i = 0; i < candidates.size(); ++i) { + ScoreStats stats = _scoreData->get(i,candidates[i]); + if (stats.size() != totals.size()) { + stringstream msg; + msg << "Statistics for (" << "," << candidates[i] << ") have incorrect " + << "number of fields. 
Found: " << stats.size() << " Expected: " + << totals.size(); + throw runtime_error(msg.str()); + } + for (size_t k = 0; k < totals.size(); ++k) { + totals[k] += stats.get(k); + } + } + scores.push_back(calculateScore(totals)); + + candidates_t last_candidates(candidates); + //apply each of the diffs, and get new scores + for (size_t i = 0; i < diffs.size(); ++i) { + for (size_t j = 0; j < diffs[i].size(); ++j) { + size_t sid = diffs[i][j].first; + size_t nid = diffs[i][j].second; + size_t last_nid = last_candidates[sid]; + for (size_t k = 0; k < totals.size(); ++k) { + int diff = _scoreData->get(sid,nid).get(k) + - _scoreData->get(sid,last_nid).get(k); + totals[k] += diff; + } + last_candidates[sid] = nid; + } + scores.push_back(calculateScore(totals)); + } + +} + diff --git a/mert/Scorer.h b/mert/Scorer.h index cac435b6c..cde8a238f 100644 --- a/mert/Scorer.h +++ b/mert/Scorer.h @@ -21,7 +21,7 @@ class ScoreStats; /** * Superclass of all scorers and dummy implementation. In order to add a new * scorer it should be sufficient to override prepareStats(), setReferenceFiles() - * and score() + * and score() (or calculateScore()). **/ class Scorer { @@ -135,4 +135,21 @@ class Scorer { }; +/** + * Abstract base class for scorers that work by adding statistics across all + * output sentences, then applying some formula, e.g. bleu, per. **/ +class StatisticsBasedScorer : public Scorer { + + public: + StatisticsBasedScorer(const string& name): Scorer(name) {} + virtual void score(const candidates_t& candidates, const diffs_t& diffs, + scores_t& scores); + + protected: + //calculate the actual score + virtual float calculateScore(const vector& totals) = 0; + +}; + + #endif //__SCORER_H