mosesdecoder/mert/HwcmScorer.cpp

165 lines
5.2 KiB
C++
Raw Normal View History

2014-09-16 14:12:14 +04:00
#include "HwcmScorer.h"
#include <fstream>
#include "ScoreStats.h"
#include "Util.h"
#include "util/tokenize_piece.hh"
// HWCM score (Liu and Gildea, 2005). Implements F1 instead of precision to better model hypothesis length.
// Assumes dependency trees on the target side (generated by scripts/training/wrappers/conll2mosesxml.py; use the --brackets option for the reference).
// Reads reference trees from a separate file {REFERENCE_FILE}.trees to support a mix of string-based and tree-based metrics.
using namespace std;
namespace MosesTuning
{
// Builds an HWCM scorer by forwarding the scorer name "HWCM" and the caller's
// configuration string to the StatisticsBasedScorer base class.
HwcmScorer::HwcmScorer(const string& config)
  : StatisticsBasedScorer("HWCM", config)
{
}
// Empty destructor body: no explicit cleanup is performed here.
HwcmScorer::~HwcmScorer()
{
}
void HwcmScorer::setReferenceFiles(const vector<string>& referenceFiles)
{
// For each line in the reference file, create a tree object
if (referenceFiles.size() != 1) {
throw runtime_error("HWCM only supports a single reference");
}
m_ref_trees.clear();
m_ref_hwc.clear();
ifstream in((referenceFiles[0] + ".trees").c_str());
if (!in) {
throw runtime_error("Unable to open " + referenceFiles[0] + ".trees");
}
string line;
while (getline(in,line)) {
line = this->preprocessSentence(line);
TreePointer tree (boost::make_shared<InternalTree>(line));
m_ref_trees.push_back(tree);
vector<map<string, int> > hwc (kHwcmOrder);
vector<string> history(kHwcmOrder);
extractHeadWordChain(tree, history, hwc);
m_ref_hwc.push_back(hwc);
vector<int> totals(kHwcmOrder);
for (size_t i = 0; i < kHwcmOrder; i++) {
for (map<string, int>::const_iterator it = m_ref_hwc.back()[i].begin(); it != m_ref_hwc.back()[i].end(); it++) {
totals[i] += it->second;
}
}
m_ref_lengths.push_back(totals);
}
TRACE_ERR(endl);
}
2015-01-14 14:07:42 +03:00
// Recursively collects head-word chains from a tree.
//
// tree     node to process; a node without children contributes nothing.
// history  chains ending at the closest ancestor that has a head word:
//          history[k] is that chain of k+1 space-separated head words, or
//          empty if no such chain exists. Must hold kHwcmOrder entries.
// hwc      output counts: hwc[k] maps a chain of k+1 space-separated head
//          words to the number of times it was seen. Must hold kHwcmOrder
//          maps.
void HwcmScorer::extractHeadWordChain(TreePointer tree, vector<string> & history, vector<map<string, int> > & hwc)
{
  if (tree->GetLength() == 0) {
    return;
  }

  // Fetch the child list once instead of once per begin()/end() call.
  const std::vector<TreePointer> & children = tree->GetChildren();
  const string head = getHead(tree);

  if (head.empty()) {
    // This node has no head word, so it does not extend any chain: its
    // children inherit the caller's history unchanged.
    for (std::vector<TreePointer>::const_iterator it = children.begin(); it != children.end(); ++it) {
      extractHeadWordChain(*it, history, hwc);
    }
    return;
  }

  // The head on its own is a chain of length one.
  vector<string> new_history(kHwcmOrder);
  new_history[0] = head;
  hwc[0][head]++;

  // Append the head to every chain inherited from the ancestors.
  for (size_t hist_idx = 0; hist_idx + 1 < kHwcmOrder; hist_idx++) {
    if (history[hist_idx].empty()) {
      continue;
    }
    const string chain = history[hist_idx] + " " + head;
    hwc[hist_idx+1][chain]++;
    // A chain that already has the maximum order cannot be extended further,
    // so it is not handed down to the children.
    if (hist_idx + 2 < kHwcmOrder) {
      new_history[hist_idx+1] = chain;
    }
  }

  for (std::vector<TreePointer>::const_iterator it = children.begin(); it != children.end(); ++it) {
    extractHeadWordChain(*it, new_history, hwc);
  }
}
2015-01-14 14:07:42 +03:00
// Returns the head word of a constituent, or an empty string if none is found.
//
// Assumption (only true for dependency parses): each constituent has a
// preterminal child, and the terminal below that preterminal is the head.
// If the constituent has several preterminals, the first one is picked; if it
// has none, the empty string is returned.
string HwcmScorer::getHead(TreePointer tree)
{
  // Fetch the child list once instead of once per begin()/end() call.
  const std::vector<TreePointer> & children = tree->GetChildren();
  for (std::vector<TreePointer>::const_iterator it = children.begin(); it != children.end(); ++it) {
    const TreePointer child = *it;
    // A preterminal has exactly one child, and that child is a terminal.
    if (child->GetLength() != 1) {
      continue;
    }
    const TreePointer grandchild = child->GetChildren()[0];
    if (grandchild->IsTerminal()) {
      return grandchild->GetLabel();
    }
  }
  return "";
}
void HwcmScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
{
if (sid >= m_ref_trees.size()) {
stringstream msg;
msg << "Sentence id (" << sid << ") not found in reference set";
throw runtime_error(msg.str());
}
string sentence = this->preprocessSentence(text);
// if sentence has '|||', assume that tree is in second position (n-best-list);
// otherwise, assume it is in first position (calling 'evaluate' with tree as reference)
util::TokenIter<util::MultiCharacter> it(sentence, util::MultiCharacter("|||"));
++it;
if (it) {
sentence = it->as_string();
}
TreePointer tree (boost::make_shared<InternalTree>(sentence));
vector<map<string, int> > hwc_test (kHwcmOrder);
vector<string> history(kHwcmOrder);
extractHeadWordChain(tree, history, hwc_test);
ostringstream stats;
for (size_t i = 0; i < kHwcmOrder; i++) {
int correct = 0;
int test_total = 0;
for (map<string, int>::const_iterator it = hwc_test[i].begin(); it != hwc_test[i].end(); it++) {
test_total += it->second;
map<string, int>::const_iterator it2 = m_ref_hwc[sid][i].find(it->first);
if (it2 != m_ref_hwc[sid][i].end()) {
correct += std::min(it->second, it2->second);
}
}
stats << correct << " " << test_total << " " << m_ref_lengths[sid][i] << " " ;
}
string stats_str = stats.str();
entry.set(stats_str);
}
float HwcmScorer::calculateScore(const vector<ScoreStatsType>& comps) const
2014-09-16 14:12:14 +04:00
{
float precision = 0;
float recall = 0;
for (size_t i = 0; i < kHwcmOrder; i++) {
float matches = comps[i*3];
float test_total = comps[1+(i*3)];
float ref_total = comps[2+(i*3)];
if (test_total > 0) {
precision += matches/test_total;
}
if (ref_total > 0) {
recall += matches/ref_total;
}
}
precision /= (float)kHwcmOrder;
recall /= (float)kHwcmOrder;
return (2*precision*recall)/(precision+recall); // f1-score
}
}