mosesdecoder/mert/CderScorer.cpp

143 lines
3.7 KiB
C++
Raw Normal View History

#include "CderScorer.h"
#include <algorithm>
#include <fstream>
#include <stdexcept>
using namespace std;
2013-05-29 21:16:15 +04:00
namespace
{

// Per-word substitution cost used by the edit-distance grid in
// CderScorer::computeCD: identical word ids are free, any mismatch costs
// exactly one edit.
inline int CalcDistance(int word1, int word2)
{
  if (word1 != word2) {
    return 1;
  }
  return 0;
}

} // namespace
namespace MosesTuning
{
2013-05-29 21:16:15 +04:00
// Constructs the scorer. With long jumps enabled the metric is CDER;
// without them computeCD() reduces to a plain word-level Levenshtein
// distance, so the base class is given the name "WER" instead.
CderScorer::CderScorer(const string& config, bool allowed_long_jumps)
  : StatisticsBasedScorer(allowed_long_jumps ? "CDER" : "WER", config)
  , m_allowed_long_jumps(allowed_long_jumps)
{
}
// Empty destructor: no explicit cleanup is performed in this file.
CderScorer::~CderScorer()
{
}
void CderScorer::setReferenceFiles(const vector<string>& referenceFiles)
{
//make sure reference data is clear
m_ref_sentences.clear();
2011-11-12 04:24:19 +04:00
//load reference data
for (size_t rid = 0; rid < referenceFiles.size(); ++rid) {
ifstream refin(referenceFiles[rid].c_str());
if (!refin) {
throw runtime_error("Unable to open: " + referenceFiles[rid]);
}
m_ref_sentences.push_back(vector<sent_t>());
string line;
while (getline(refin,line)) {
line = this->preprocessSentence(line);
2011-11-12 04:24:19 +04:00
sent_t encoded;
TokenizeAndEncode(line, encoded);
m_ref_sentences[rid].push_back(encoded);
}
2011-11-12 04:24:19 +04:00
}
}
void CderScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
{
string sentence = this->preprocessSentence(text);
vector<ScoreStatsType> stats;
prepareStatsVector(sid, sentence, stats);
entry.set(stats);
}
void CderScorer::prepareStatsVector(size_t sid, const string& text, vector<ScoreStatsType>& stats)
{
2011-11-12 04:24:19 +04:00
sent_t cand;
TokenizeAndEncode(text, cand);
2011-11-12 04:24:19 +04:00
float max = -2;
vector<ScoreStatsType> tmp;
for (size_t rid = 0; rid < m_ref_sentences.size(); ++rid) {
2012-02-26 08:04:27 +04:00
const sent_t& ref = m_ref_sentences[rid][sid];
tmp.clear();
computeCD(cand, ref, tmp);
2013-06-03 18:59:13 +04:00
int score = calculateScore(tmp);
2013-06-05 16:42:56 +04:00
if (rid == 0) {
2011-11-12 04:24:19 +04:00
stats = tmp;
2013-06-03 18:59:13 +04:00
max = score;
2013-06-05 16:42:56 +04:00
} else if (score > max) {
2013-06-03 18:59:13 +04:00
stats = tmp;
max = score;
}
2011-11-12 04:24:19 +04:00
}
}
float CderScorer::calculateScore(const vector<ScoreStatsType>& comps) const
{
2012-02-26 08:04:27 +04:00
if (comps.size() != 2) {
2011-11-12 04:24:19 +04:00
throw runtime_error("Size of stat vector for CDER is not 2");
}
2012-03-18 02:35:56 +04:00
if (comps[1] == 0) return 1.0f;
2012-02-26 08:04:27 +04:00
return 1.0f - (comps[0] / static_cast<float>(comps[1]));
}
2012-02-26 08:04:27 +04:00
void CderScorer::computeCD(const sent_t& cand, const sent_t& ref,
vector<ScoreStatsType>& stats) const
2013-05-29 21:16:15 +04:00
{
2011-11-12 04:24:19 +04:00
int I = cand.size() + 1; // Number of inter-words positions in candidate sentence
int L = ref.size() + 1; // Number of inter-words positions in reference sentence
int l = 0;
// row[i] stores cost of cheapest path from (0,0) to (i,l) in CDER aligment grid.
vector<int>* row = new vector<int>(I);
// Initialization of first row
2013-06-03 18:59:13 +04:00
for (int i = 0; i < I; ++i) (*row)[i] = i;
2013-06-05 16:42:56 +04:00
2013-06-03 18:59:13 +04:00
// For CDER metric, the initialization is different
if (m_allowed_long_jumps) {
for (int i = 1; i < I; ++i) (*row)[i] = 1;
}
2011-11-12 04:24:19 +04:00
// Calculating costs for next row using costs from the previous row.
2013-05-29 21:16:15 +04:00
while (++l < L) {
2011-11-12 04:24:19 +04:00
vector<int>* nextRow = new vector<int>(I);
2013-05-29 21:16:15 +04:00
for (int i = 0; i < I; ++i) {
2011-11-12 04:24:19 +04:00
vector<int> possibleCosts;
if (i > 0) {
possibleCosts.push_back((*nextRow)[i-1] + 1); // Deletion
possibleCosts.push_back((*row)[i-1] + CalcDistance(ref[l-1], cand[i-1])); // Substitution/Identity
2011-11-12 04:24:19 +04:00
}
possibleCosts.push_back((*row)[i] + 1); // Insertion
(*nextRow)[i] = *min_element(possibleCosts.begin(), possibleCosts.end());
}
if (m_allowed_long_jumps) {
// Cost of LongJumps is the same for all in the row
int LJ = 1 + *min_element(nextRow->begin(), nextRow->end());
for (int i = 0; i < I; ++i) {
(*nextRow)[i] = min((*nextRow)[i], LJ); // LongJumps
}
2011-11-12 04:24:19 +04:00
}
delete row;
row = nextRow;
}
2012-02-26 08:04:27 +04:00
stats.resize(2);
2011-11-12 04:24:19 +04:00
stats[0] = *(row->rbegin()); // CD distance is the cost of path from (0,0) to (I,L)
stats[1] = ref.size();
delete row;
}
}