Move implementation details from the header to the .cpp file.

Also add const to variables that we don't want to change.
This commit is contained in:
Tetsuo Kiso 2012-11-05 01:24:16 +09:00
parent cccfb9a0c9
commit 96f7b42eb9
2 changed files with 81 additions and 78 deletions

View File

@ -6,17 +6,61 @@
// Copyright 2012 __MyCompanyName__. All rights reserved. // Copyright 2012 __MyCompanyName__. All rights reserved.
// //
#include <iostream>
#include "SentenceLevelScorer.h" #include "SentenceLevelScorer.h"
#include <iostream>
#include <boost/spirit/home/support/detail/lexer/runtime_error.hpp>
using namespace std; using namespace std;
namespace MosesTuning namespace MosesTuning
{ {
/** The sentence level scores have already been calculated, just need to average them SentenceLevelScorer::SentenceLevelScorer(const string& name, const string& config)
and include the differences. Allows scores which are floats **/ : Scorer(name, config),
m_regularisationStrategy(REG_NONE),
m_regularisationWindow(0) {
Init();
}
// Empty destructor: this class performs no cleanup of its own.
SentenceLevelScorer::~SentenceLevelScorer() {}
void SentenceLevelScorer::Init() {
// Configure regularisation.
static string KEY_TYPE = "regtype";
static string KEY_WINDOW = "regwin";
static string KEY_CASE = "case";
static string TYPE_NONE = "none";
static string TYPE_AVERAGE = "average";
static string TYPE_MINIMUM = "min";
static string TRUE = "true";
static string FALSE = "false";
const string type = getConfig(KEY_TYPE, TYPE_NONE);
if (type == TYPE_NONE) {
m_regularisationStrategy = REG_NONE;
} else if (type == TYPE_AVERAGE) {
m_regularisationStrategy = REG_AVERAGE;
} else if (type == TYPE_MINIMUM) {
m_regularisationStrategy = REG_MINIMUM;
} else {
throw boost::lexer::runtime_error("Unknown scorer regularisation strategy: " + type);
}
cerr << "Using scorer regularisation strategy: " << type << endl;
const string window = getConfig(KEY_WINDOW, "0");
m_regularisationWindow = atoi(window.c_str());
cerr << "Using scorer regularisation window: " << m_regularisationWindow << endl;
const string preservecase = getConfig(KEY_CASE, TRUE);
if (preservecase == TRUE) {
m_enable_preserve_case = true;
} else if (preservecase == FALSE) {
m_enable_preserve_case = false;
}
cerr << "Using case preservation: " << m_enable_preserve_case << endl;
}
void SentenceLevelScorer::score(const candidates_t& candidates, const diffs_t& diffs, void SentenceLevelScorer::score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores) statscores_t& scores)
{ {
@ -31,7 +75,7 @@ void SentenceLevelScorer::score(const candidates_t& candidates, const diffs_t&
if (candidates.size() == 0) { if (candidates.size() == 0) {
throw runtime_error("No candidates supplied"); throw runtime_error("No candidates supplied");
} }
int numCounts = m_score_data->get(0,candidates[0]).size(); const int numCounts = m_score_data->get(0,candidates[0]).size();
vector<float> totals(numCounts); vector<float> totals(numCounts);
for (size_t i = 0; i < candidates.size(); ++i) { for (size_t i = 0; i < candidates.size(); ++i) {
//cout << " i " << i << " candi " << candidates[i] ; //cout << " i " << i << " candi " << candidates[i] ;
@ -57,21 +101,21 @@ void SentenceLevelScorer::score(const candidates_t& candidates, const diffs_t&
totals[k] /= candidates.size(); totals[k] /= candidates.size();
//cout << "finaltotals = " << totals[k] << endl; //cout << "finaltotals = " << totals[k] << endl;
} }
scores.push_back(calculateScore(totals)); scores.push_back(calculateScore(totals));
candidates_t last_candidates(candidates); candidates_t last_candidates(candidates);
//apply each of the diffs, and get new scores //apply each of the diffs, and get new scores
for (size_t i = 0; i < diffs.size(); ++i) { for (size_t i = 0; i < diffs.size(); ++i) {
for (size_t j = 0; j < diffs[i].size(); ++j) { for (size_t j = 0; j < diffs[i].size(); ++j) {
size_t sid = diffs[i][j].first; const size_t sid = diffs[i][j].first;
size_t nid = diffs[i][j].second; const size_t nid = diffs[i][j].second;
//cout << "sid = " << sid << endl; //cout << "sid = " << sid << endl;
//cout << "nid = " << nid << endl; //cout << "nid = " << nid << endl;
size_t last_nid = last_candidates[sid]; const size_t last_nid = last_candidates[sid];
for (size_t k = 0; k < totals.size(); ++k) { for (size_t k = 0; k < totals.size(); ++k) {
float diff = m_score_data->get(sid,nid).get(k) const float diff = m_score_data->get(sid,nid).get(k)
- m_score_data->get(sid,last_nid).get(k); - m_score_data->get(sid,last_nid).get(k);
//cout << "diff = " << diff << endl; //cout << "diff = " << diff << endl;
totals[k] += diff/candidates.size(); totals[k] += diff/candidates.size();
//cout << "totals = " << totals[k] << endl; //cout << "totals = " << totals[k] << endl;
@ -80,29 +124,28 @@ void SentenceLevelScorer::score(const candidates_t& candidates, const diffs_t&
} }
scores.push_back(calculateScore(totals)); scores.push_back(calculateScore(totals));
} }
//regularisation. This can either be none, or the min or average as described in //regularisation. This can either be none, or the min or average as described in
//Cer, Jurafsky and Manning at WMT08 //Cer, Jurafsky and Manning at WMT08
if (_regularisationStrategy == REG_NONE || _regularisationWindow <= 0) { if (m_regularisationStrategy == REG_NONE || m_regularisationWindow <= 0) {
//no regularisation //no regularisation
return; return;
} }
//window size specifies the +/- in each direction //window size specifies the +/- in each direction
statscores_t raw_scores(scores);//copy scores statscores_t raw_scores(scores);//copy scores
for (size_t i = 0; i < scores.size(); ++i) { for (size_t i = 0; i < scores.size(); ++i) {
size_t start = 0; size_t start = 0;
if (i >= _regularisationWindow) { if (i >= m_regularisationWindow) {
start = i - _regularisationWindow; start = i - m_regularisationWindow;
} }
size_t end = min(scores.size(), i + _regularisationWindow+1); const size_t end = min(scores.size(), i + m_regularisationWindow+1);
if (_regularisationStrategy == REG_AVERAGE) { if (m_regularisationStrategy == REG_AVERAGE) {
scores[i] = score_average(raw_scores,start,end); scores[i] = score_average(raw_scores, start, end);
} else { } else {
scores[i] = score_min(raw_scores,start,end); scores[i] = score_min(raw_scores, start, end);
} }
} }
} }
} }

View File

@ -12,77 +12,37 @@
#include "Scorer.h" #include "Scorer.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include <vector>
#include <boost/spirit/home/support/detail/lexer/runtime_error.hpp>
namespace MosesTuning namespace MosesTuning
{ {
/** /**
* Abstract base class for scorers that work by using sentence level * Abstract base class for scorers that work by using sentence level
* statistics eg. permutation distance metrics **/ * statistics (e.g., permutation distance metrics). **/
class SentenceLevelScorer : public Scorer class SentenceLevelScorer : public Scorer
{ {
public: public:
SentenceLevelScorer(const std::string& name, const std::string& config): Scorer(name,config) { SentenceLevelScorer(const std::string& name, const std::string& config);
//configure regularisation ~SentenceLevelScorer();
static std::string KEY_TYPE = "regtype";
static std::string KEY_WINDOW = "regwin"; /** The sentence level scores have already been calculated, just need to average them
static std::string KEY_CASE = "case"; and include the differences. Allows scores which are floats. **/
static std::string TYPE_NONE = "none";
static std::string TYPE_AVERAGE = "average";
static std::string TYPE_MINIMUM = "min";
static std::string TRUE = "true";
static std::string FALSE = "false";
std::string type = getConfig(KEY_TYPE,TYPE_NONE);
if (type == TYPE_NONE) {
_regularisationStrategy = REG_NONE;
} else if (type == TYPE_AVERAGE) {
_regularisationStrategy = REG_AVERAGE;
} else if (type == TYPE_MINIMUM) {
_regularisationStrategy = REG_MINIMUM;
} else {
throw boost::lexer::runtime_error("Unknown scorer regularisation strategy: " + type);
}
std::cerr << "Using scorer regularisation strategy: " << type << std::endl;
std::string window = getConfig(KEY_WINDOW,"0");
_regularisationWindow = atoi(window.c_str());
std::cerr << "Using scorer regularisation window: " << _regularisationWindow << std::endl;
std::string preservecase = getConfig(KEY_CASE,TRUE);
if (preservecase == TRUE) {
m_enable_preserve_case = true;
} else if (preservecase == FALSE) {
m_enable_preserve_case = false;
}
std::cerr << "Using case preservation: " << m_enable_preserve_case << std::endl;
}
~SentenceLevelScorer() {};
virtual void score(const candidates_t& candidates, const diffs_t& diffs, virtual void score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores); statscores_t& scores);
//calculate the actual score // calculate the actual score *
virtual statscore_t calculateScore(const std::vector<statscore_t>& totals) { virtual statscore_t calculateScore(const std::vector<statscore_t>& totals) const {
return 0; return 0;
}; }
protected: protected:
// Set up regularisation parameters.
void Init();
//regularisation //regularisation
ScorerRegularisationStrategy _regularisationStrategy; ScorerRegularisationStrategy m_regularisationStrategy;
size_t _regularisationWindow; size_t m_regularisationWindow;
}; };
} }
#endif #endif