merge Lexi Birch's LRScore from mert_mtm5 branch. Compiles and run. Hack, must double check with barry or lexi

This commit is contained in:
Hieu Hoang 2012-06-23 22:51:48 -04:00
parent f48c348508
commit 0cb63edcb9
14 changed files with 192 additions and 161 deletions

View File

@ -105,6 +105,8 @@
1E3962211594CFF9006FE978 /* Permutation.h in Headers */ = {isa = PBXBuildFile; fileRef = 1E39621F1594CFF9006FE978 /* Permutation.h */; };
1E3962231594D0FF006FE978 /* SentenceLevelScorer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1E3962221594D0FF006FE978 /* SentenceLevelScorer.cpp */; };
1E3962251594D12C006FE978 /* SentenceLevelScorer.h in Headers */ = {isa = PBXBuildFile; fileRef = 1E3962241594D12C006FE978 /* SentenceLevelScorer.h */; };
1EE52B561596B3E4006DC938 /* StatisticsBasedScorer.h in Headers */ = {isa = PBXBuildFile; fileRef = 1EE52B551596B3E4006DC938 /* StatisticsBasedScorer.h */; };
1EE52B591596B3FC006DC938 /* StatisticsBasedScorer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1EE52B581596B3FC006DC938 /* StatisticsBasedScorer.cpp */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
@ -207,6 +209,8 @@
1E39621F1594CFF9006FE978 /* Permutation.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = Permutation.h; path = ../../mert/Permutation.h; sourceTree = "<group>"; };
1E3962221594D0FF006FE978 /* SentenceLevelScorer.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = SentenceLevelScorer.cpp; path = ../../mert/SentenceLevelScorer.cpp; sourceTree = "<group>"; };
1E3962241594D12C006FE978 /* SentenceLevelScorer.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = SentenceLevelScorer.h; path = ../../mert/SentenceLevelScorer.h; sourceTree = "<group>"; };
1EE52B551596B3E4006DC938 /* StatisticsBasedScorer.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = StatisticsBasedScorer.h; path = ../../mert/StatisticsBasedScorer.h; sourceTree = "<group>"; };
1EE52B581596B3FC006DC938 /* StatisticsBasedScorer.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = StatisticsBasedScorer.cpp; path = ../../mert/StatisticsBasedScorer.cpp; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
@ -223,6 +227,8 @@
1E2CCF2815939E2D00D858D1 = {
isa = PBXGroup;
children = (
1EE52B581596B3FC006DC938 /* StatisticsBasedScorer.cpp */,
1EE52B551596B3E4006DC938 /* StatisticsBasedScorer.h */,
1E3962241594D12C006FE978 /* SentenceLevelScorer.h */,
1E3962221594D0FF006FE978 /* SentenceLevelScorer.cpp */,
1E39621E1594CFF9006FE978 /* Permutation.cpp */,
@ -401,6 +407,7 @@
1E39621C1594CFD1006FE978 /* PermutationScorer.h in Headers */,
1E3962211594CFF9006FE978 /* Permutation.h in Headers */,
1E3962251594D12C006FE978 /* SentenceLevelScorer.h in Headers */,
1EE52B561596B3E4006DC938 /* StatisticsBasedScorer.h in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@ -497,6 +504,7 @@
1E39621B1594CFD1006FE978 /* PermutationScorer.cpp in Sources */,
1E3962201594CFF9006FE978 /* Permutation.cpp in Sources */,
1E3962231594D0FF006FE978 /* SentenceLevelScorer.cpp in Sources */,
1EE52B591596B3FC006DC938 /* StatisticsBasedScorer.cpp in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};

View File

@ -162,7 +162,7 @@ void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
entry.set(stats);
}
float BleuScorer::calculateScore(const vector<int>& comps) const
statscore_t BleuScorer::calculateScore(const vector<int>& comps) const
{
CHECK(comps.size() == kBleuNgramOrder * 2 + 1);

View File

@ -7,7 +7,7 @@
#include "Types.h"
#include "ScoreData.h"
#include "Scorer.h"
#include "StatisticsBasedScorer.h"
#include "ScopedVector.h"
const int kBleuNgramOrder = 4;
@ -32,7 +32,7 @@ public:
virtual void setReferenceFiles(const std::vector<std::string>& referenceFiles);
virtual void prepareStats(std::size_t sid, const std::string& text, ScoreStats& entry);
virtual float calculateScore(const std::vector<int>& comps) const;
virtual statscore_t calculateScore(const std::vector<int>& comps) const;
virtual std::size_t NumberOfScores() const { return 2 * kBleuNgramOrder + 1; }
int CalcReferenceLength(std::size_t sentence_id, std::size_t length);

View File

@ -4,7 +4,7 @@
#include <string>
#include <vector>
#include "Types.h"
#include "Scorer.h"
#include "StatisticsBasedScorer.h"
/**
* CderScorer class can compute both CDER and WER metric.

View File

@ -4,7 +4,7 @@
#include <string>
#include <vector>
#include "Scorer.h"
#include "StatisticsBasedScorer.h"
class PerScorer;
class ScoreStats;

View File

@ -5,7 +5,7 @@
#include <string>
#include <vector>
#include "Types.h"
#include "Scorer.h"
#include "StatisticsBasedScorer.h"
class ScoreStats;

View File

@ -6,7 +6,7 @@ using namespace std;
const int PermutationScorer::SCORE_PRECISION = 5;
PermutationScorer::PermutationScorer(const string &distanceMetric, const string &config)
:SentenceLevelScorer(distanceMetric,config)
:StatisticsBasedScorer(distanceMetric,config)
{
//configure regularisation
@ -206,14 +206,8 @@ void PermutationScorer::prepareStats(size_t sid, const string& text, ScoreStats&
//cout << tempStream.str();
}
void PermutationScorer::score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores) const
{
assert(false);
}
//Will just be final score
statscore_t PermutationScorer::calculateScore(const vector<statscore_t>& comps)
statscore_t PermutationScorer::calculateScore(const vector<int>& comps) const
{
//cerr << "*******PermutationScorer::calculateScore" ;
//cerr << " " << comps[0] << endl;

View File

@ -15,12 +15,12 @@
#include "ScoreData.h"
#include "Scorer.h"
#include "Permutation.h"
#include "SentenceLevelScorer.h"
#include "StatisticsBasedScorer.h"
/**
* Permutation
**/
class PermutationScorer: public SentenceLevelScorer
class PermutationScorer: public StatisticsBasedScorer
{
public:
@ -39,11 +39,9 @@ public:
return true;
};
void score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores) const;
protected:
statscore_t calculateScore(const std::vector<statscore_t>& scores);
statscore_t calculateScore(const std::vector<int>& scores) const;
PermutationScorer(const PermutationScorer&);
~PermutationScorer() {};
PermutationScorer& operator=(const PermutationScorer&);

View File

@ -130,110 +130,10 @@ string Scorer::applyFilter(const string& sentence) const
}
}
StatisticsBasedScorer::StatisticsBasedScorer(const string& name, const string& config)
: Scorer(name,config) {
//configure regularisation
static string KEY_TYPE = "regtype";
static string KEY_WINDOW = "regwin";
static string KEY_CASE = "case";
static string TYPE_NONE = "none";
static string TYPE_AVERAGE = "average";
static string TYPE_MINIMUM = "min";
static string TRUE = "true";
static string FALSE = "false";
string type = getConfig(KEY_TYPE,TYPE_NONE);
if (type == TYPE_NONE) {
m_regularization_type = NONE;
} else if (type == TYPE_AVERAGE) {
m_regularization_type = AVERAGE;
} else if (type == TYPE_MINIMUM) {
m_regularization_type = MINIMUM;
} else {
throw runtime_error("Unknown scorer regularisation strategy: " + type);
}
// cerr << "Using scorer regularisation strategy: " << type << endl;
const string& window = getConfig(KEY_WINDOW, "0");
m_regularization_window = atoi(window.c_str());
// cerr << "Using scorer regularisation window: " << m_regularization_window << endl;
const string& preserve_case = getConfig(KEY_CASE,TRUE);
if (preserve_case == TRUE) {
m_enable_preserve_case = true;
} else if (preserve_case == FALSE) {
m_enable_preserve_case = false;
}
// cerr << "Using case preservation: " << m_enable_preserve_case << endl;
float Scorer::score(const candidates_t& candidates) const {
diffs_t diffs;
statscores_t scores;
score(candidates, diffs, scores);
return scores[0];
}
void StatisticsBasedScorer::score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores) const
{
if (!m_score_data) {
throw runtime_error("Score data not loaded");
}
// calculate the score for the candidates
if (m_score_data->size() == 0) {
throw runtime_error("Score data is empty");
}
if (candidates.size() == 0) {
throw runtime_error("No candidates supplied");
}
int numCounts = m_score_data->get(0,candidates[0]).size();
vector<int> totals(numCounts);
for (size_t i = 0; i < candidates.size(); ++i) {
ScoreStats stats = m_score_data->get(i,candidates[i]);
if (stats.size() != totals.size()) {
stringstream msg;
msg << "Statistics for (" << "," << candidates[i] << ") have incorrect "
<< "number of fields. Found: " << stats.size() << " Expected: "
<< totals.size();
throw runtime_error(msg.str());
}
for (size_t k = 0; k < totals.size(); ++k) {
totals[k] += stats.get(k);
}
}
scores.push_back(calculateScore(totals));
candidates_t last_candidates(candidates);
// apply each of the diffs, and get new scores
for (size_t i = 0; i < diffs.size(); ++i) {
for (size_t j = 0; j < diffs[i].size(); ++j) {
size_t sid = diffs[i][j].first;
size_t nid = diffs[i][j].second;
size_t last_nid = last_candidates[sid];
for (size_t k = 0; k < totals.size(); ++k) {
int diff = m_score_data->get(sid,nid).get(k)
- m_score_data->get(sid,last_nid).get(k);
totals[k] += diff;
}
last_candidates[sid] = nid;
}
scores.push_back(calculateScore(totals));
}
// Regularisation. This can either be none, or the min or average as described in
// Cer, Jurafsky and Manning at WMT08.
if (m_regularization_type == NONE || m_regularization_window <= 0) {
// no regularisation
return;
}
// window size specifies the +/- in each direction
statscores_t raw_scores(scores); // copy scores
for (size_t i = 0; i < scores.size(); ++i) {
size_t start = 0;
if (i >= m_regularization_window) {
start = i - m_regularization_window;
}
const size_t end = min(scores.size(), i + m_regularization_window + 1);
if (m_regularization_type == AVERAGE) {
scores[i] = score_average(raw_scores,start,end);
} else {
scores[i] = score_min(raw_scores,start,end);
}
}
}

View File

@ -79,12 +79,7 @@ class Scorer
* Calculate the score of the sentences corresponding to the list of candidate
* indices. Each index indicates the 1-best choice from the n-best list.
*/
float score(const candidates_t& candidates) const {
diffs_t diffs;
statscores_t scores;
score(candidates, diffs, scores);
return scores[0];
}
float score(const candidates_t& candidates) const;
const std::string& getName() const {
return m_name;
@ -167,36 +162,6 @@ class Scorer
};
/**
* Abstract base class for Scorers that work by adding statistics across all
* outout sentences, then apply some formula, e.g., BLEU, PER.
*/
class StatisticsBasedScorer : public Scorer
{
public:
StatisticsBasedScorer(const std::string& name, const std::string& config);
virtual ~StatisticsBasedScorer() {}
virtual void score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores) const;
protected:
enum RegularisationType {
NONE,
AVERAGE,
MINIMUM
};
/**
* Calculate the actual score.
*/
virtual statscore_t calculateScore(const std::vector<int>& totals) const = 0;
// regularisation
RegularisationType m_regularization_type;
std::size_t m_regularization_window;
};
namespace {
//regularisation strategies
@ -227,4 +192,5 @@ namespace {
} // namespace
#endif // MERT_SCORER_H_

View File

@ -15,6 +15,7 @@
// However, currently SemposScorer uses a bunch of typedefs, which are
// used in SemposScorer as well as inherited SemposOverlapping classes.
#include "SemposOverlapping.h"
#include "StatisticsBasedScorer.h"
/**
* This class represents sempos based metrics.

View File

@ -0,0 +1,119 @@
//
// StatisticsBasedScorer.cpp
// mert_lib
//
// Created by Hieu Hoang on 23/06/2012.
// Copyright 2012 __MyCompanyName__. All rights reserved.
//
#include <iostream>
#include "StatisticsBasedScorer.h"
using namespace std;
StatisticsBasedScorer::StatisticsBasedScorer(const string& name, const string& config)
: Scorer(name,config) {
//configure regularisation
static string KEY_TYPE = "regtype";
static string KEY_WINDOW = "regwin";
static string KEY_CASE = "case";
static string TYPE_NONE = "none";
static string TYPE_AVERAGE = "average";
static string TYPE_MINIMUM = "min";
static string TRUE = "true";
static string FALSE = "false";
string type = getConfig(KEY_TYPE,TYPE_NONE);
if (type == TYPE_NONE) {
m_regularization_type = NONE;
} else if (type == TYPE_AVERAGE) {
m_regularization_type = AVERAGE;
} else if (type == TYPE_MINIMUM) {
m_regularization_type = MINIMUM;
} else {
throw runtime_error("Unknown scorer regularisation strategy: " + type);
}
// cerr << "Using scorer regularisation strategy: " << type << endl;
const string& window = getConfig(KEY_WINDOW, "0");
m_regularization_window = atoi(window.c_str());
// cerr << "Using scorer regularisation window: " << m_regularization_window << endl;
const string& preserve_case = getConfig(KEY_CASE,TRUE);
if (preserve_case == TRUE) {
m_enable_preserve_case = true;
} else if (preserve_case == FALSE) {
m_enable_preserve_case = false;
}
// cerr << "Using case preservation: " << m_enable_preserve_case << endl;
}
void StatisticsBasedScorer::score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores) const
{
if (!m_score_data) {
throw runtime_error("Score data not loaded");
}
// calculate the score for the candidates
if (m_score_data->size() == 0) {
throw runtime_error("Score data is empty");
}
if (candidates.size() == 0) {
throw runtime_error("No candidates supplied");
}
int numCounts = m_score_data->get(0,candidates[0]).size();
vector<int> totals(numCounts);
for (size_t i = 0; i < candidates.size(); ++i) {
ScoreStats stats = m_score_data->get(i,candidates[i]);
if (stats.size() != totals.size()) {
stringstream msg;
msg << "Statistics for (" << "," << candidates[i] << ") have incorrect "
<< "number of fields. Found: " << stats.size() << " Expected: "
<< totals.size();
throw runtime_error(msg.str());
}
for (size_t k = 0; k < totals.size(); ++k) {
totals[k] += stats.get(k);
}
}
scores.push_back(calculateScore(totals));
candidates_t last_candidates(candidates);
// apply each of the diffs, and get new scores
for (size_t i = 0; i < diffs.size(); ++i) {
for (size_t j = 0; j < diffs[i].size(); ++j) {
size_t sid = diffs[i][j].first;
size_t nid = diffs[i][j].second;
size_t last_nid = last_candidates[sid];
for (size_t k = 0; k < totals.size(); ++k) {
int diff = m_score_data->get(sid,nid).get(k)
- m_score_data->get(sid,last_nid).get(k);
totals[k] += diff;
}
last_candidates[sid] = nid;
}
scores.push_back(calculateScore(totals));
}
// Regularisation. This can either be none, or the min or average as described in
// Cer, Jurafsky and Manning at WMT08.
if (m_regularization_type == NONE || m_regularization_window <= 0) {
// no regularisation
return;
}
// window size specifies the +/- in each direction
statscores_t raw_scores(scores); // copy scores
for (size_t i = 0; i < scores.size(); ++i) {
size_t start = 0;
if (i >= m_regularization_window) {
start = i - m_regularization_window;
}
const size_t end = min(scores.size(), i + m_regularization_window + 1);
if (m_regularization_type == AVERAGE) {
scores[i] = score_average(raw_scores,start,end);
} else {
scores[i] = score_min(raw_scores,start,end);
}
}
}

View File

@ -0,0 +1,45 @@
//
// StatisticsBasedScorer.h
// mert_lib
//
// Created by Hieu Hoang on 23/06/2012.
// Copyright 2012 __MyCompanyName__. All rights reserved.
//
#ifndef mert_lib_StatisticsBasedScorer_h
#define mert_lib_StatisticsBasedScorer_h
#include "Scorer.h"
/**
* Abstract base class for Scorers that work by adding statistics across all
* outout sentences, then apply some formula, e.g., BLEU, PER.
*/
class StatisticsBasedScorer : public Scorer
{
public:
StatisticsBasedScorer(const std::string& name, const std::string& config);
virtual ~StatisticsBasedScorer() {}
virtual void score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores) const;
protected:
enum RegularisationType {
NONE,
AVERAGE,
MINIMUM
};
/**
* Calculate the actual score.
*/
virtual statscore_t calculateScore(const std::vector<int>& totals) const = 0;
// regularisation
RegularisationType m_regularization_type;
std::size_t m_regularization_window;
};
#endif

View File

@ -6,7 +6,7 @@
#include <vector>
#include "Types.h"
#include "Scorer.h"
#include "StatisticsBasedScorer.h"
class ScoreStats;