merge Lexi Birch's LRScore from mert_mtm5 branch

This commit is contained in:
Hieu Hoang 2012-06-22 18:19:16 +01:00
parent db06d9cf65
commit 7d19fe13ae
12 changed files with 922 additions and 30 deletions

View File

@ -99,6 +99,12 @@
1E2CD02115939E5D00D858D1 /* Util.h in Headers */ = {isa = PBXBuildFile; fileRef = 1E2CCFB415939E5D00D858D1 /* Util.h */; };
1E2CD02315939E5D00D858D1 /* Vocabulary.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1E2CCFB615939E5D00D858D1 /* Vocabulary.cpp */; };
1E2CD02415939E5D00D858D1 /* Vocabulary.h in Headers */ = {isa = PBXBuildFile; fileRef = 1E2CCFB715939E5D00D858D1 /* Vocabulary.h */; };
1E39621B1594CFD1006FE978 /* PermutationScorer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1E3962191594CFD1006FE978 /* PermutationScorer.cpp */; };
1E39621C1594CFD1006FE978 /* PermutationScorer.h in Headers */ = {isa = PBXBuildFile; fileRef = 1E39621A1594CFD1006FE978 /* PermutationScorer.h */; };
1E3962201594CFF9006FE978 /* Permutation.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1E39621E1594CFF9006FE978 /* Permutation.cpp */; };
1E3962211594CFF9006FE978 /* Permutation.h in Headers */ = {isa = PBXBuildFile; fileRef = 1E39621F1594CFF9006FE978 /* Permutation.h */; };
1E3962231594D0FF006FE978 /* SentenceLevelScorer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1E3962221594D0FF006FE978 /* SentenceLevelScorer.cpp */; };
1E3962251594D12C006FE978 /* SentenceLevelScorer.h in Headers */ = {isa = PBXBuildFile; fileRef = 1E3962241594D12C006FE978 /* SentenceLevelScorer.h */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
@ -195,6 +201,12 @@
1E2CCFB415939E5D00D858D1 /* Util.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = Util.h; path = ../../mert/Util.h; sourceTree = "<group>"; };
1E2CCFB615939E5D00D858D1 /* Vocabulary.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = Vocabulary.cpp; path = ../../mert/Vocabulary.cpp; sourceTree = "<group>"; };
1E2CCFB715939E5D00D858D1 /* Vocabulary.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = Vocabulary.h; path = ../../mert/Vocabulary.h; sourceTree = "<group>"; };
1E3962191594CFD1006FE978 /* PermutationScorer.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = PermutationScorer.cpp; path = ../../mert/PermutationScorer.cpp; sourceTree = "<group>"; };
1E39621A1594CFD1006FE978 /* PermutationScorer.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = PermutationScorer.h; path = ../../mert/PermutationScorer.h; sourceTree = "<group>"; };
1E39621E1594CFF9006FE978 /* Permutation.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = Permutation.cpp; path = ../../mert/Permutation.cpp; sourceTree = "<group>"; };
1E39621F1594CFF9006FE978 /* Permutation.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = Permutation.h; path = ../../mert/Permutation.h; sourceTree = "<group>"; };
1E3962221594D0FF006FE978 /* SentenceLevelScorer.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = SentenceLevelScorer.cpp; path = ../../mert/SentenceLevelScorer.cpp; sourceTree = "<group>"; };
1E3962241594D12C006FE978 /* SentenceLevelScorer.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = SentenceLevelScorer.h; path = ../../mert/SentenceLevelScorer.h; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
@ -211,6 +223,12 @@
1E2CCF2815939E2D00D858D1 = {
isa = PBXGroup;
children = (
1E3962241594D12C006FE978 /* SentenceLevelScorer.h */,
1E3962221594D0FF006FE978 /* SentenceLevelScorer.cpp */,
1E39621E1594CFF9006FE978 /* Permutation.cpp */,
1E39621F1594CFF9006FE978 /* Permutation.h */,
1E3962191594CFD1006FE978 /* PermutationScorer.cpp */,
1E39621A1594CFD1006FE978 /* PermutationScorer.h */,
1E2CCF3A15939E5D00D858D1 /* BleuScorer.cpp */,
1E2CCF3B15939E5D00D858D1 /* BleuScorer.h */,
1E2CCF3D15939E5D00D858D1 /* CderScorer.cpp */,
@ -380,6 +398,9 @@
1E2CD01F15939E5D00D858D1 /* Types.h in Headers */,
1E2CD02115939E5D00D858D1 /* Util.h in Headers */,
1E2CD02415939E5D00D858D1 /* Vocabulary.h in Headers */,
1E39621C1594CFD1006FE978 /* PermutationScorer.h in Headers */,
1E3962211594CFF9006FE978 /* Permutation.h in Headers */,
1E3962251594D12C006FE978 /* SentenceLevelScorer.h in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@ -473,6 +494,9 @@
1E2CD01C15939E5D00D858D1 /* Timer.cpp in Sources */,
1E2CD02015939E5D00D858D1 /* Util.cpp in Sources */,
1E2CD02315939E5D00D858D1 /* Vocabulary.cpp in Sources */,
1E39621B1594CFD1006FE978 /* PermutationScorer.cpp in Sources */,
1E3962201594CFF9006FE978 /* Permutation.cpp in Sources */,
1E3962231594D0FF006FE978 /* SentenceLevelScorer.cpp in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};

320
mert/Permutation.cpp Normal file
View File

@ -0,0 +1,320 @@
/*
* Permutation.cpp
* met - Minimum Error Training
*
* Created by Alexandra Birch 18/11/09.
*
*/
#include <fstream>
#include <sstream>
#include <math.h>
#include "Permutation.h"
#include "Util.h"
using namespace std;
// Constructs a permutation from a "source-target" alignment string.
// The alignment is parsed (via set) only when sourceLength is positive;
// otherwise the permutation stays empty. targetLength is always recorded.
Permutation::Permutation(const string &alignment, const int sourceLength, const int targetLength )
  : m_targetLength(targetLength)
{
  if (sourceLength <= 0) {
    return;
  }
  set(alignment, sourceLength);
}
// Returns the number of source positions stored in the permutation.
size_t Permutation::getLength() const
{
  // Return the container size directly; casting through int only narrowed
  // the value before it was widened back to size_t.
  return m_array.size();
}
// Prints the permutation to stdout as "(index:value), " pairs on one line,
// followed by a newline.
void Permutation::dump() const
{
  for (size_t pos = 0; pos < m_array.size(); ++pos) {
    cout << "(" << pos << ":" << m_array[pos] << "), ";
  }
  cout << endl;
}
//Sent alignment string
//Eg: "0-1 0-0 1-2 3-0 4-5 6-7 "
// Individual word alignments, which can be one-to-one, null aligned or
// many-to-many. Each token has the format sourcepos-targetpos.
// It is the output of the Berkeley aligner with 1 subtracted from each number.
// sourceLength is needed because the last source words might not be aligned.
//
// Fills m_array so that m_array[s] is the rank of source position s once the
// source positions are ordered by their resolved target position:
//   1. an aligned source position keeps the smallest target position it is
//      aligned to;
//   2. an unaligned source position inherits the target position of the
//      source position before it (0 if there is none);
//   3. source positions are ranked by target position; equal target positions
//      are visited in multimap order (insertion order for equal keys since
//      C++11, i.e. source order).
// Worked example, sourceLength = 4, alignment "0-2 1-0 3-1":
//   target per source: 2 0 -1 1  ->  2 0 0 1  ->  m_array: 3 0 1 2
//
// Terminates the process with exit(1) when sourceLength is not positive, when
// a token has no '-' delimiter, or when a source position lies outside
// [0, sourceLength). Throws runtime_error if the result is not a valid
// permutation.
void Permutation::set(const string & alignment,const int sourceLength)
{
  if (sourceLength <= 0) {
    cerr << "Source sentence length not positive:"<< sourceLength << endl;
    // Fatal input error: signal failure through the exit status, like the
    // other error paths of this function.
    exit(1);
  }
  if (alignment.empty()) {
    //alignment empty - could happen but not good
    cerr << "Alignment string empty:"<< alignment << endl;
  }

  // One target position per source position; -1 marks "not aligned (yet)".
  vector<int> tempPerm(sourceLength, -1);

  //Tokenise on whitespace
  stringstream ss(alignment);
  string token;
  while (ss >> token) {
    const string::size_type posDelimiter = token.find("-");
    if (posDelimiter == string::npos) {
      cerr << "Delimiter not found - :"<< token << endl;
      exit(1);
    }
    const int sourcePos = atoi(token.substr(0, posDelimiter).c_str());
    const int targetPos = atoi(token.substr(posDelimiter+1).c_str());
    // Valid indices are 0 .. sourceLength-1; sourcePos == sourceLength would
    // already write past the end of tempPerm.
    if (sourcePos < 0 || sourcePos >= sourceLength) {
      cerr << "Source sentence length:" << sourceLength << " is too small for alignment source position:" << sourcePos << endl;
      exit(1);
    }
    //If multiple target positions are aligned to one source position,
    // keep only the smallest target position
    if (tempPerm[sourcePos] == -1 || tempPerm[sourcePos] > targetPos) {
      tempPerm[sourcePos] = targetPos;
    }
  }

  // 1st step: give every null-aligned source position the target position of
  // the preceding source position, and index source positions by target
  // position (key is target pos, value is source pos).
  m_array.assign(sourceLength, -1);
  multimap<int, int> invMap;
  int last = 0;
  for (size_t i = 0; i < tempPerm.size(); i++) {
    if (tempPerm[i] == -1) {
      tempPerm[i] = last;
    } else {
      last = tempPerm[i];
    }
    invMap.insert(pair<int,int>(tempPerm[i], int(i)));
  }

  // 2nd step: walk the index in increasing target order and hand out
  // consecutive ranks to the source positions.
  int rank = 0;
  for (multimap<int, int>::const_iterator it = invMap.begin(); it != invMap.end(); ++it) {
    m_array[it->second] = rank;
    rank++;
  }

  if (!checkValidPermutation(m_array)) {
    throw runtime_error(" Created invalid permutation");
  }
}
//Static
// Returns the inverse mapping of inVector: result[inVector[i]] == i for every
// index i. The values of inVector are used as indices without checking, so
// the caller has to pass a valid permutation of 0 .. size-1.
vector<int> Permutation::invert(const vector<int> & inVector)
{
  vector<int> inverse(inVector.size());
  size_t index = 0;
  for (vector<int>::const_iterator it = inVector.begin(); it != inVector.end(); ++it, ++index) {
    inverse[*it] = int(index);
  }
  return inverse;
}
//Static
//Permutations start at 0
// Returns true when inVector contains every value 0 .. size-1 exactly once.
// On failure a diagnostic is written to stderr and false is returned.
bool Permutation::checkValidPermutation(vector<int> const & inVector)
{
  const size_t length = inVector.size();
  vector<char> seen(length, 0);
  for (size_t i = 0; i < length; ++i) {
    const int value = inVector[i];
    // Reject values outside 0 .. size-1 before using them as an index. With
    // 'length' entries, an out-of-range value also means some in-range value
    // is missing.
    if (value < 0 || static_cast<size_t>(value) >= length) {
      cerr << "Permutation error: missing values\n" << endl;
      return false;
    }
    //No multiple entries of same value allowed
    if (seen[value]) {
      cerr << "Permutation error: multiple entries of same value\n" << endl;
      return false;
    }
    seen[value] = 1;
  }
  // 'length' distinct values, all inside 0 .. length-1: there can be no holes.
  return true;
}
// Scores this permutation against permCompare with the requested metric
// (calculateHamming or calculateKendall) and applies a brevity penalty:
// when permCompare's target length is greater than this permutation's target
// length, the score is multiplied by exp(1 - compareLength/thisLength).
// Throws runtime_error for an unknown metric type.
float Permutation::distance(const Permutation &permCompare, const distanceMetric_t &type) const
{
  float score;
  switch (type) {
  case HAMMING_DISTANCE:
    score = calculateHamming(permCompare);
    break;
  case KENDALL_DISTANCE:
    score = calculateKendall(permCompare);
    break;
  default:
    throw runtime_error("Distance type not valid");
  }

  // compared target length divided by this target length
  const float lengthRatio = (float) permCompare.getTargetLength()/getTargetLength();
  float brevityPenalty = 1.0 - lengthRatio;
  if (!(brevityPenalty < 0.0)) {
    return score;
  }
  score = score * exp(brevityPenalty);
  return score;
}
float Permutation::calculateHamming(const Permutation & compare) const
{
float score=0;
vector<int> compareArray = compare.getArray();
if (getLength() != compare.getLength()) {
cerr << "1stperm: " << getLength() << " 2ndperm: " << compare.getLength() << endl;
throw runtime_error("Length of permutations not equal");
}
if (getLength() == 0) {
cerr << "Empty permutation" << endl;
return 0;
}
for (size_t i=0; i<getLength(); i++) {
if (m_array[i] != compareArray[i]) {
score++;
}
}
score = 1 - (score / getLength());
return score;
}
float Permutation::calculateKendall(const Permutation & compare) const
{
float score=0;
vector<int> compareArray = compare.getArray();
if (getLength() != compare.getLength()) {
cerr << "1stperm: " << getLength() << " 2ndperm: " << compare.getLength() << endl;
throw runtime_error("Length of permutations not equal");
}
if (getLength() == 0) {
cerr << "Empty permutation" << endl;
return 0;
}
for (size_t i=0; i<getLength(); i++) {
for (size_t j=0; j<getLength(); j++) {
if ((m_array[i] < m_array[j]) && (compareArray[i] > compareArray[j])) {
score++;
}
}
}
score = (score / ((getLength()*getLength() - getLength()) /2 ) );
//Adjusted Kendall's tau correlates better with human judgements
score = sqrt (score);
score = 1 - score;
return score;
}
// Returns a copy of the permutation array.
vector<int> Permutation::getArray() const
{
  return m_array;
}
// Parses a position span, either "N" or "N-M", into first/last (both N for a
// single position). Conversion is done with atoi.
static void parsePositionSpan(const string &span, int &first, int &last)
{
  const string::size_type posDash = span.find("-");
  if (posDash != string::npos) {
    first = atoi(span.substr(0, posDash).c_str());
    last = atoi(span.substr(posDash+1).c_str());
  } else {
    first = atoi(span.c_str());
    last = first;
  }
}
//Static
//This function is called with the 5th field of the moses nbest output
// produced with -include-alignment-in-n-best
//eg. 0=0 1-2=1-2 3=3 4=4 5=5 6=6 7-9=7-8 10=9 11-13=10-11 (source=target)
// Every "sourceSpan=targetSpan" token is expanded into all source-target
// pairs it covers, e.g. "1-2=1-2" becomes "1-1 1-2 2-1 2-2 ". The result is
// the "source-target " list format accepted by Permutation::set.
// Terminates the process with exit(1) when a token has no '=' delimiter.
string Permutation::convertMosesToStandard(string const & alignment)
{
  if (alignment.length() == 0) {
    cerr << "Alignment input string empty" << endl;
  }
  string working = alignment;
  stringstream oss;
  while (working.length() > 0) {
    string align;
    getNextPound(working,align," ");
    if (align.length() == 0) {
      // nothing extracted (e.g. consecutive delimiters): try the next token
      continue;
    }
    const string::size_type posEquals = align.find("=");
    if (posEquals == string::npos) {
      cerr << "Delimiter not found = :"<< align << endl;
      // Malformed input is an error: report it through the exit status.
      exit(1);
    }
    int firstSourcePos, lastSourcePos, firstTargetPos, lastTargetPos;
    parsePositionSpan(align.substr(0, posEquals), firstSourcePos, lastSourcePos);
    parsePositionSpan(align.substr(posEquals+1), firstTargetPos, lastTargetPos);
    for (int i = firstSourcePos; i <= lastSourcePos; i++) {
      for (int j = firstTargetPos; j <= lastTargetPos; j++) {
        oss << i << "-" << j << " ";
      }
    }
  }
  return oss.str();
}

66
mert/Permutation.h Normal file
View File

@ -0,0 +1,66 @@
/*
* Permutation.h
* met - Minimum Error Training
*
* Created by Alexandra Birch 18 Nov 2009.
*
*/
#ifndef PERMUTATION_H
#define PERMUTATION_H
#include <limits>
#include <vector>
#include <iostream>
#include <fstream>
#include "Util.h"
// A permutation of source word positions derived from a source-target word
// alignment, together with the length of the target sentence. Two
// permutations can be compared with a Hamming- or Kendall-based score.
class Permutation
{

public:
  // alignment is a whitespace-separated list of "sourcePos-targetPos" tokens.
  // It is parsed only when sourceLength is positive; otherwise the
  // permutation stays empty.
  Permutation(const std::string &alignment = std::string(), const int sourceLength = 0, const int targetLength = 0 );

  ~Permutation() {}

  // Removes all positions from the permutation.
  inline void clear() {
    m_array.clear();
  }
  // Number of positions in the permutation.
  inline size_t size() const {
    return m_array.size();
  }

  // (Re)builds the permutation from an alignment string; sourceLength must be
  // positive.
  void set(const std::string &alignment,const int sourceLength);

  // Scores this permutation against permCompare.
  // strategy can be HAMMING_DISTANCE or KENDALL_DISTANCE.
  float distance(const Permutation &permCompare, const distanceMetric_t &strategy = HAMMING_DISTANCE) const;

  //Const
  // Prints the permutation to stdout.
  void dump() const;
  // Number of positions in the permutation (same value as size()).
  size_t getLength() const;
  // Copy of the permutation array.
  std::vector<int> getArray() const;
  // Target sentence length given at construction.
  int getTargetLength() const {
    return m_targetLength;
  }

  //Static
  // Expands moses "sourceSpan=targetSpan" alignments into "source-target" pairs.
  static std::string convertMosesToStandard(std::string const & alignment);
  // Inverse mapping of a valid permutation.
  static std::vector<int> invert(std::vector<int> const & inVector);
  // True when inVector holds every value 0 .. size-1 exactly once.
  static bool checkValidPermutation(std::vector<int> const & inVector);

protected:
  std::vector<int> m_array;   // m_array[sourcePos] = rank in target order
  int m_targetLength;         // number of words in the target sentence
  float calculateHamming(const Permutation & compare) const;
  float calculateKendall(const Permutation & compare) const;
};
#endif

215
mert/PermutationScorer.cpp Normal file
View File

@ -0,0 +1,215 @@
#include "PermutationScorer.h"
using namespace std;
// Stream precision used by prepareStats when it formats the sentence score
// into the text stored in the ScoreStats entry.
const int PermutationScorer::SCORE_PRECISION = 5;
// Configures the scorer from distanceMetric ("HAMMING" or "KENDALL") and the
// config string. Config keys read here:
//   refchoice - "average" or "closest" (default "closest"): how the scores
//               against several references are combined;
//   refalign  - '+'-separated list of reference alignment files;
//   source    - source text file; the number of words on each line is stored
//               as the source sentence length.
// Throws runtime_error for an unknown refchoice value, an unknown distance
// metric, or a source file that cannot be opened.
PermutationScorer::PermutationScorer(const string &distanceMetric, const string &config)
  :SentenceLevelScorer(distanceMetric,config)
{
  //configure reference choice strategy
  static string KEY_REFCHOICE = "refchoice";
  static string REFCHOICE_AVERAGE = "average";
  static string REFCHOICE_CLOSEST = "closest";

  string refchoice = getConfig(KEY_REFCHOICE,REFCHOICE_CLOSEST);
  if (refchoice == REFCHOICE_AVERAGE) {
    m_refChoiceStrategy = REFERENCE_CHOICE_AVERAGE;
  } else if (refchoice == REFCHOICE_CLOSEST) {
    m_refChoiceStrategy = REFERENCE_CHOICE_CLOSEST;
  } else {
    throw runtime_error("Unknown reference choice strategy: " + refchoice);
  }
  cerr << "Using reference choice strategy: " << refchoice << endl;

  if (distanceMetric.compare("HAMMING") == 0) {
    m_distanceMetric = HAMMING_DISTANCE;
  } else if (distanceMetric.compare("KENDALL") == 0) {
    m_distanceMetric = KENDALL_DISTANCE;
  } else {
    // Without this branch m_distanceMetric would stay uninitialised and be
    // read later by prepareStats.
    throw runtime_error("Unknown permutation distance metric: " + distanceMetric);
  }
  cerr << "Using permutation distance metric: " << distanceMetric << endl;

  //Get reference alignments from scconfig refalign option
  static string KEY_ALIGNMENT_FILES = "refalign";
  string refalign = getConfig(KEY_ALIGNMENT_FILES,"");
  if (refalign.length() > 0) {
    string substring;
    while (!refalign.empty()) {
      getNextPound(refalign, substring, "+");
      m_referenceAlignments.push_back(substring);
    }
  }

  //Get length of source sentences read in from scconfig source option
  // this is essential for extractor but unnecessary for mert executable
  static string KEY_SOURCE_FILE = "source";
  string sourceFile = getConfig(KEY_SOURCE_FILE,"");
  if (sourceFile.length() > 0) {
    cerr << "Loading source sentence lengths from " << sourceFile << endl;
    ifstream sourcein(sourceFile.c_str());
    if (!sourcein) {
      throw runtime_error("Unable to open: " + sourceFile);
    }
    string line;
    while (getline(sourcein,line)) {
      // count the pieces getNextPound splits off the line
      size_t wordNumber = 0;
      string word;
      while(!line.empty()) {
        getNextPound(line, word, " ");
        wordNumber++;
      }
      m_sourceLengths.push_back(wordNumber);
    }
    sourcein.close();
  }
}
// Loads the reference data.
// The reference text files are only used to obtain the number of words of
// every reference sentence; the reference permutations themselves are built
// from the alignment files named by the "refalign" config option, one
// Permutation per line, using the source sentence lengths read by the
// constructor.
// Throws runtime_error when a file cannot be opened, when there are fewer
// reference files / source sentences / reference sentences than alignment
// data refers to, or when a permutation's length differs from the source
// sentence length.
void PermutationScorer::setReferenceFiles(const vector<string>& referenceFiles)
{
  cout << "*******setReferenceFiles" << endl;
  //make sure reference data is clear
  m_referencePerms.clear();

  //Just getting target length from reference text file
  vector< vector< int> > targetLengths;
  for (size_t i = 0; i < referenceFiles.size(); ++i) {
    vector <int> lengths;
    cout << "Loading reference from " << referenceFiles[i] << endl;
    ifstream refin(referenceFiles[i].c_str());
    if (!refin) {
      cerr << "Unable to open: " << referenceFiles[i] << endl;
      throw runtime_error("Unable to open reference file");
    }
    string line;
    while (getline(refin,line)) {
      int count = getNumberWords(line);
      lengths.push_back(count);
    }
    targetLengths.push_back(lengths);
  }

  //load reference data
  //NOTE ignoring normal reference file, only using previously saved alignment reference files
  for (size_t i = 0; i < m_referenceAlignments.size(); ++i) {
    // alignment file i is paired with reference text file i
    if (i >= targetLengths.size()) {
      throw runtime_error("Fewer reference files than reference alignment files");
    }
    vector<Permutation> referencePerms;
    cout << "Loading reference from " << m_referenceAlignments[i] << endl;
    ifstream refin(m_referenceAlignments[i].c_str());
    if (!refin) {
      cerr << "Unable to open: " << m_referenceAlignments[i] << endl;
      throw runtime_error("Unable to open alignment file");
    }
    string line;
    size_t sid = 0; //sentence counter
    while (getline(refin,line)) {
      if (sid >= m_sourceLengths.size()) {
        cerr << "No source length for sid " << sid << endl;
        throw runtime_error("More reference alignments than source sentences");
      }
      if (sid >= targetLengths[i].size()) {
        cerr << "No reference length for sid " << sid << endl;
        throw runtime_error("More reference alignments than reference sentences");
      }
      //Line needs to be of the format: 0-0 1-1 1-2 etc source-target
      Permutation perm(line, m_sourceLengths[sid],targetLengths[i][sid]);
      //check the source sentence length is the same for previous file
      if (perm.getLength() != m_sourceLengths[sid]) {
        cerr << "Permutation Length: " << perm.getLength() << endl;
        cerr << "Source length: " << m_sourceLengths[sid] << " for sid " << sid << endl;
        throw runtime_error("Source sentence lengths not the same: ");
      }
      referencePerms.push_back(perm);
      sid++;
    }
    m_referencePerms.push_back(referencePerms);
  }
}
// Counts the words of text after trimming it with trimStr: 0 for an empty
// result, otherwise the number of space characters plus one.
int PermutationScorer::getNumberWords (const string& text) const
{
  const string line = trimStr(text);
  if (line.length() == 0) {
    return 0;
  }
  // one word to start with, one more for every space found
  int count = 1;
  for (string::size_type pos = line.find(" "); pos != string::npos; pos = line.find(" ", pos+1)) {
    count++;
  }
  return count;
}
void PermutationScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
{
//cout << "*******prepareStats" ;
//cout << text << endl;
//cout << sid << endl;
//cout << "Reference0align:" << endl;
//m_referencePerms[0][sid].dump();
string sentence = "";
string align = text;
size_t alignmentData = text.find("|||");
//Get sentence and alignment parts
if(alignmentData != string::npos) {
getNextPound(align,sentence, "|||");
} else {
align = text;
}
int translationLength = getNumberWords(sentence);
//A vector of Permutations for each sentence
vector< vector<Permutation> > nBestPerms;
float distanceValue;
//need to create permutations for each nbest line
string standardFormat = Permutation::convertMosesToStandard(align);
Permutation perm(standardFormat, m_sourceLengths[sid],translationLength);
//perm.dump();
if (m_refChoiceStrategy == REFERENCE_CHOICE_AVERAGE) {
float total = 0;
for (size_t i = 0; i < m_referencePerms.size(); ++i) {
float dist = perm.distance(m_referencePerms[i][sid], m_distanceMetric);
total += dist;
//cout << "Ref number: " << i << " distance: " << dist << endl;
}
float mean = (float)total/m_referencePerms.size();
//cout << "MultRef strategy AVERAGE: total " << total << " mean " << mean << " number " << m_referencePerms.size() << endl;
distanceValue = mean;
} else if (m_refChoiceStrategy == REFERENCE_CHOICE_CLOSEST) {
float max_val = 0;
for (size_t i = 0; i < m_referencePerms.size(); ++i) {
//look for the closest reference
float value = perm.distance(m_referencePerms[i][sid], m_distanceMetric);
//cout << "Ref number: " << i << " distance: " << value << endl;
if (value > max_val) {
max_val = value;
}
}
distanceValue = max_val;
//cout << "MultRef strategy CLOSEST: max_val " << distanceValue << endl;
} else {
throw runtime_error("Unsupported reflength strategy");
}
//SCOREROUT eg: 0.04546
ostringstream tempStream;
tempStream.precision(SCORE_PRECISION);
tempStream << distanceValue;
string str = tempStream.str();
entry.set(str);
//cout << tempStream.str();
}
// The score is simply the first (and, per NumberOfScores, only) component.
statscore_t PermutationScorer::calculateScore(const vector<statscore_t>& comps)
{
  return comps.front();
}

64
mert/PermutationScorer.h Normal file
View File

@ -0,0 +1,64 @@
#ifndef __PERMUTATIONSCORER_H__
#define __PERMUTATIONSCORER_H__
#include <algorithm>
#include <cmath>
#include <iostream>
#include <iterator>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include <limits.h>
#include "Types.h"
#include "ScoreData.h"
#include "Scorer.h"
#include "Permutation.h"
#include "SentenceLevelScorer.h"
/**
 * Sentence-level scorer that compares the word-order permutation of a
 * translation with the permutations of one or more references, using the
 * Hamming or Kendall metric of Permutation.
 **/
class PermutationScorer: public SentenceLevelScorer
{

public:
  // distanceMetric is "HAMMING" or "KENDALL"; config is the scorer config string.
  PermutationScorer(const std::string &distanceMetric = "HAMMING",
                    const std::string &config = std::string());

  // Loads reference lengths and reference permutations.
  void setReferenceFiles(const std::vector<std::string>& referenceFiles);
  // Scores one n-best entry of sentence sid and stores the result in entry.
  void prepareStats(size_t sid, const std::string& text, ScoreStats& entry);

  // Stream precision used when a score is written out as text.
  static const int SCORE_PRECISION;

  // A single score is produced per sentence.
  size_t NumberOfScores() const {
    return 1;
  }

  // This scorer needs the alignment information of the n-best list.
  bool useAlignment() const {
    return true;
  }

protected:
  // Returns the first component of scores.
  statscore_t calculateScore(const std::vector<statscore_t>& scores);

  // Copying is declared here without an inline definition.
  PermutationScorer(const PermutationScorer&);
  ~PermutationScorer() {}
  PermutationScorer& operator=(const PermutationScorer&);

  // Number of words in line (space characters + 1 after trimming).
  int getNumberWords (const std::string & line) const;

  distanceMetricReferenceChoice_t m_refChoiceStrategy;
  distanceMetric_t m_distanceMetric;

  // data extracted from reference files
  // A vector of permutations for each reference file
  std::vector< std::vector<Permutation> > m_referencePerms;
  // number of words of every source sentence
  std::vector<size_t> m_sourceLengths;
  // names of the reference alignment files
  std::vector<std::string> m_referenceAlignments;
};
//TODO need to read in floats for scores - necessary for selecting mean reference strategy and for BLEU?
#endif //__PERMUTATIONSCORER_H

View File

@ -8,36 +8,6 @@
using namespace std;
namespace {
//regularisation strategies
// Smallest value of scores in the index range [start, end);
// the largest finite float when the range is empty.
inline float score_min(const statscores_t& scores, size_t start, size_t end)
{
  float lowest = std::numeric_limits<float>::max();
  size_t pos = start;
  while (pos < end) {
    if (scores[pos] < lowest) {
      lowest = scores[pos];
    }
    ++pos;
  }
  return lowest;
}
// Mean of scores over the index range [start, end); 0 when start == end.
inline float score_average(const statscores_t& scores, size_t start, size_t end)
{
  const size_t count = end - start;
  if (count == 0) {
    // this shouldn't happen
    return 0;
  }
  float sum = 0;
  size_t pos = start;
  while (pos < end) {
    sum += scores[pos];
    ++pos;
  }
  return sum / count;
}
} // namespace
Scorer::Scorer(const string& name, const string& config)
: m_name(name),
m_vocab(mert::VocabularyFactory::GetVocabulary()),

View File

@ -18,6 +18,8 @@ class Vocabulary;
} // namespace mert
// How a score is regularised over a window of neighbouring scores:
// not at all, by their average, or by their minimum.
enum ScorerRegularisationStrategy {REG_NONE, REG_AVERAGE, REG_MINIMUM};
/**
* Superclass of all scorers and dummy implementation.
*
@ -195,4 +197,34 @@ class StatisticsBasedScorer : public Scorer
std::size_t m_regularization_window;
};
//regularisation strategies
// These helpers live in a header, so they are plain inline functions at
// namespace scope instead of members of an unnamed namespace: an unnamed
// namespace in a header gives every including translation unit its own
// private copy of everything declared in it.

// Smallest value of scores in the index range [start, end);
// the largest finite float when the range is empty.
inline float score_min(const statscores_t& scores, size_t start, size_t end)
{
  float min = std::numeric_limits<float>::max();
  for (size_t i = start; i < end; ++i) {
    if (scores[i] < min) {
      min = scores[i];
    }
  }
  return min;
}

// Mean of scores over the index range [start, end); 0 when start == end.
inline float score_average(const statscores_t& scores, size_t start, size_t end)
{
  if ((end - start) < 1) {
    // this shouldn't happen
    return 0;
  }
  float total = 0;
  for (size_t j = start; j < end; ++j) {
    total += scores[j];
  }
  return total / (end - start);
}
#endif // MERT_SCORER_H_

View File

@ -9,6 +9,7 @@
#include "MergeScorer.h"
#include "InterpolatedScorer.h"
#include "SemposScorer.h"
#include "PermutationScorer.h"
using namespace std;
@ -21,6 +22,7 @@ vector<string> ScorerFactory::getTypes() {
types.push_back(string("WER"));
types.push_back(string("MERGE"));
types.push_back(string("SEMPOS"));
types.push_back(string("LRSCORE"));
return types;
}
@ -40,6 +42,8 @@ Scorer* ScorerFactory::getScorer(const string& type, const string& config) {
return new SemposScorer(config);
} else if (type == "MERGE") {
return new MergeScorer(config);
} else if (type == "MERGE") {
  // NOTE(review): unreachable as written - the branch directly above already
  // handles type == "MERGE", so this PermutationScorer is never constructed.
  // getTypes() advertises "LRSCORE"; presumably that was the intended key -
  // TODO confirm. Also, PermutationScorer's first constructor parameter is the
  // distance metric name, so passing only config here looks wrong - verify.
  return new PermutationScorer(config);
} else {
if (type.find(',') != string::npos) {
return new InterpolatedScorer(type, config);

View File

@ -0,0 +1,102 @@
//
// SentenceLevelScorer.cpp
// mert_lib
//
// Created by Hieu Hoang on 22/06/2012.
// Copyright 2012 __MyCompanyName__. All rights reserved.
//
#include <iostream>
#include "SentenceLevelScorer.h"
using namespace std;
/** The sentence level scores have already been calculated, just need to average them
and include the differences. Allows scores which are floats **/
void SentenceLevelScorer::score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores)
{
//cout << "*******SentenceLevelScorer::score" << endl;
if (!m_score_data) {
throw runtime_error("Score data not loaded");
}
//calculate the score for the candidates
if (m_score_data->size() == 0) {
throw runtime_error("Score data is empty");
}
if (candidates.size() == 0) {
throw runtime_error("No candidates supplied");
}
int numCounts = m_score_data->get(0,candidates[0]).size();
vector<float> totals(numCounts);
for (size_t i = 0; i < candidates.size(); ++i) {
//cout << " i " << i << " candi " << candidates[i] ;
ScoreStats stats = m_score_data->get(i,candidates[i]);
if (stats.size() != totals.size()) {
stringstream msg;
msg << "Statistics for (" << "," << candidates[i] << ") have incorrect "
<< "number of fields. Found: " << stats.size() << " Expected: "
<< totals.size();
throw runtime_error(msg.str());
}
//Add up scores for all sentences, would normally be just one score
for (size_t k = 0; k < totals.size(); ++k) {
totals[k] += stats.get(k);
//cout << " stats " << stats.get(k) ;
}
//cout << endl;
}
//take average
for (size_t k = 0; k < totals.size(); ++k) {
//cout << "totals = " << totals[k] << endl;
//cout << "cand = " << candidates.size() << endl;
totals[k] /= candidates.size();
//cout << "finaltotals = " << totals[k] << endl;
}
scores.push_back(calculateScore(totals));
candidates_t last_candidates(candidates);
//apply each of the diffs, and get new scores
for (size_t i = 0; i < diffs.size(); ++i) {
for (size_t j = 0; j < diffs[i].size(); ++j) {
size_t sid = diffs[i][j].first;
size_t nid = diffs[i][j].second;
//cout << "sid = " << sid << endl;
//cout << "nid = " << nid << endl;
size_t last_nid = last_candidates[sid];
for (size_t k = 0; k < totals.size(); ++k) {
float diff = m_score_data->get(sid,nid).get(k)
- m_score_data->get(sid,last_nid).get(k);
//cout << "diff = " << diff << endl;
totals[k] += diff/candidates.size();
//cout << "totals = " << totals[k] << endl;
}
last_candidates[sid] = nid;
}
scores.push_back(calculateScore(totals));
}
//regularisation. This can either be none, or the min or average as described in
//Cer, Jurafsky and Manning at WMT08
if (_regularisationStrategy == REG_NONE || _regularisationWindow <= 0) {
//no regularisation
return;
}
//window size specifies the +/- in each direction
statscores_t raw_scores(scores);//copy scores
for (size_t i = 0; i < scores.size(); ++i) {
size_t start = 0;
if (i >= _regularisationWindow) {
start = i - _regularisationWindow;
}
size_t end = min(scores.size(), i + _regularisationWindow+1);
if (_regularisationStrategy == REG_AVERAGE) {
scores[i] = score_average(raw_scores,start,end);
} else {
scores[i] = score_min(raw_scores,start,end);
}
}
}

View File

@ -0,0 +1,83 @@
//
// SentenceLevelScorer.h
// mert_lib
//
// Created by Hieu Hoang on 22/06/2012.
// Copyright 2012 __MyCompanyName__. All rights reserved.
//
#ifndef mert_lib_SentenceLevelScorer_h
#define mert_lib_SentenceLevelScorer_h
#include "Scorer.h"
#include <string>
#include <vector>
#include <vector>
#include <boost/spirit/home/support/detail/lexer/runtime_error.hpp>
/**
 * Abstract base class for scorers that work by using sentence level
 * statistics eg. permutation distance metrics.
 *
 * Config keys read by the constructor:
 *   regtype - "none" (default), "average" or "min": regularisation strategy;
 *   regwin  - regularisation window, +/- in each direction (default 0);
 *   case    - "true" (default) or "false": case preservation.
 **/
class SentenceLevelScorer : public Scorer
{

public:
  SentenceLevelScorer(const std::string& name, const std::string& config): Scorer(name,config) {
    //configure regularisation
    static std::string KEY_TYPE = "regtype";
    static std::string KEY_WINDOW = "regwin";
    static std::string KEY_CASE = "case";
    static std::string TYPE_NONE = "none";
    static std::string TYPE_AVERAGE = "average";
    static std::string TYPE_MINIMUM = "min";
    // Not named TRUE / FALSE: several platform headers define those as macros.
    static std::string VALUE_TRUE = "true";
    static std::string VALUE_FALSE = "false";

    std::string type = getConfig(KEY_TYPE,TYPE_NONE);
    if (type == TYPE_NONE) {
      _regularisationStrategy = REG_NONE;
    } else if (type == TYPE_AVERAGE) {
      _regularisationStrategy = REG_AVERAGE;
    } else if (type == TYPE_MINIMUM) {
      _regularisationStrategy = REG_MINIMUM;
    } else {
      // NOTE(review): this is boost's lexer exception type, not
      // std::runtime_error as used elsewhere in the scorers - confirm intent.
      throw boost::lexer::runtime_error("Unknown scorer regularisation strategy: " + type);
    }
    std::cerr << "Using scorer regularisation strategy: " << type << std::endl;

    std::string window = getConfig(KEY_WINDOW,"0");
    // A negative value would wrap around to a huge unsigned window; treat it
    // as "no window".
    const int requestedWindow = atoi(window.c_str());
    _regularisationWindow = requestedWindow > 0 ? static_cast<size_t>(requestedWindow) : 0;
    std::cerr << "Using scorer regularisation window: " << _regularisationWindow << std::endl;

    // Any value other than "true"/"false" leaves m_enable_preserve_case untouched.
    std::string preservecase = getConfig(KEY_CASE,VALUE_TRUE);
    if (preservecase == VALUE_TRUE) {
      m_enable_preserve_case = true;
    } else if (preservecase == VALUE_FALSE) {
      m_enable_preserve_case = false;
    }
    std::cerr << "Using case preservation: " << m_enable_preserve_case << std::endl;
  }

  ~SentenceLevelScorer() {}

  // Averages the sentence-level statistics of the chosen candidates and
  // appends the resulting score(s) to scores.
  virtual void score(const candidates_t& candidates, const diffs_t& diffs,
                     statscores_t& scores);

  //calculate the actual score
  // Default implementation returns 0; subclasses override it.
  virtual statscore_t calculateScore(const std::vector<statscore_t>& totals) {
    return 0;
  }

protected:
  //regularisation
  ScorerRegularisationStrategy _regularisationStrategy;
  size_t _regularisationWindow;
};
#endif

View File

@ -39,4 +39,7 @@ typedef std::vector<ScoreArray> scoredata_t;
typedef std::map<std::size_t, std::string> idx2name;
typedef std::map<std::string, std::size_t> name2idx;
// Metric used to compare two permutations.
typedef enum { HAMMING_DISTANCE=0, KENDALL_DISTANCE } distanceMetric_t;
// How the scores against several references are combined: their average, or
// the score of the closest reference.
typedef enum { REFERENCE_CHOICE_AVERAGE=0, REFERENCE_CHOICE_CLOSEST } distanceMetricReferenceChoice_t;
#endif // MERT_TYPE_H_

View File

@ -116,6 +116,15 @@ inline FeatureStatsType ConvertStringToFeatureStatsType(const std::string &str)
return ConvertCharToFeatureStatsType(str.c_str());
}
// Returns Src without the leading and trailing characters contained in c
// (space, carriage return and newline by default). Returns an empty string
// when Src is empty or consists only of characters from c.
inline std::string trimStr(const std::string& Src, const std::string& c = " \r\n")
{
  // Keep the positions in size_type: stored in an unsigned int, a truncated
  // npos no longer compares equal to std::string::npos where size_t is wider,
  // and the substr call below then throws std::out_of_range.
  const std::string::size_type last = Src.find_last_not_of(c);
  if (last == std::string::npos) return std::string();
  // A last kept character exists, so a first one exists as well.
  const std::string::size_type first = Src.find_first_not_of(c);
  return Src.substr(first, (last-first)+1);
}
// Utilities to measure decoding time
void ResetUserTime();
void PrintUserTime(const std::string &message);