2014-09-16 14:12:14 +04:00
|
|
|
#include "HwcmScorer.h"
|
|
|
|
|
|
|
|
#include <fstream>
|
|
|
|
|
|
|
|
#include "ScoreStats.h"
|
|
|
|
#include "Util.h"
|
|
|
|
|
|
|
|
#include "util/tokenize_piece.hh"
|
|
|
|
|
|
|
|
// HWCM score (Liu and Gildea, 2005). Implements F1 instead of precision for better modelling of hypothesis length.
|
|
|
|
// assumes dependency trees on target side (generated by scripts/training/wrappers/conll2mosesxml.py ; use with option --brackets for reference).
|
|
|
|
// reads reference trees from the separate file {REFERENCE_FILE}.trees to support a mix of string-based and tree-based metrics.
|
|
|
|
|
|
|
|
using namespace std;
|
|
|
|
|
|
|
|
namespace MosesTuning
|
|
|
|
{
|
|
|
|
|
|
|
|
|
|
|
|
// Builds the scorer, registering it with the statistics-based base class
// under the name "HWCM" and forwarding the user-supplied config string.
HwcmScorer::HwcmScorer(const string& config)
  : StatisticsBasedScorer("HWCM", config)
{
}
|
|
|
|
|
|
|
|
// The destructor has no explicit work; members are destroyed automatically.
HwcmScorer::~HwcmScorer()
{
}
|
|
|
|
|
|
|
|
void HwcmScorer::setReferenceFiles(const vector<string>& referenceFiles)
|
|
|
|
{
|
|
|
|
// For each line in the reference file, create a tree object
|
|
|
|
if (referenceFiles.size() != 1) {
|
|
|
|
throw runtime_error("HWCM only supports a single reference");
|
|
|
|
}
|
|
|
|
m_ref_trees.clear();
|
|
|
|
m_ref_hwc.clear();
|
|
|
|
ifstream in((referenceFiles[0] + ".trees").c_str());
|
|
|
|
if (!in) {
|
|
|
|
throw runtime_error("Unable to open " + referenceFiles[0] + ".trees");
|
|
|
|
}
|
|
|
|
string line;
|
|
|
|
while (getline(in,line)) {
|
|
|
|
line = this->preprocessSentence(line);
|
|
|
|
TreePointer tree (boost::make_shared<InternalTree>(line));
|
|
|
|
m_ref_trees.push_back(tree);
|
|
|
|
vector<map<string, int> > hwc (kHwcmOrder);
|
|
|
|
vector<string> history(kHwcmOrder);
|
|
|
|
extractHeadWordChain(tree, history, hwc);
|
|
|
|
m_ref_hwc.push_back(hwc);
|
|
|
|
vector<int> totals(kHwcmOrder);
|
|
|
|
for (size_t i = 0; i < kHwcmOrder; i++) {
|
|
|
|
for (map<string, int>::const_iterator it = m_ref_hwc.back()[i].begin(); it != m_ref_hwc.back()[i].end(); it++) {
|
|
|
|
totals[i] += it->second;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
m_ref_lengths.push_back(totals);
|
|
|
|
}
|
|
|
|
TRACE_ERR(endl);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
2015-01-14 14:07:42 +03:00
|
|
|
// Recursively counts head-word chains in `tree`, adding them to `hwc`.
//
// hwc[k] maps a space-separated chain of k+1 head words to its count.
// history[k] holds the chain of k+1 head words that ends at the nearest
// enclosing constituent that had a head; an empty entry means no such chain.
// A constituent's head is obtained from getHead(). When a constituent has no
// head, its children are visited with the incoming history unchanged.
// Otherwise the head is counted on its own and appended to every non-empty
// history chain, and the children are visited with a fresh history rooted at
// this head.
void HwcmScorer::extractHeadWordChain(TreePointer tree, vector<string> & history, vector<map<string, int> > & hwc)
{
  // Nodes without children contribute nothing.
  if (tree->GetLength() == 0) {
    return;
  }

  const string head = getHead(tree);
  const vector<TreePointer>& children = tree->GetChildren();

  if (head.empty()) {
    // Headless constituent: pass the surrounding chain context straight through.
    for (vector<TreePointer>::const_iterator child = children.begin(); child != children.end(); ++child) {
      extractHeadWordChain(*child, history, hwc);
    }
    return;
  }

  vector<string> child_history(kHwcmOrder);
  child_history[0] = head;
  hwc[0][head]++;

  // Extend each ancestor chain of (order) words by this head, giving a chain
  // of (order + 1) words; keep it as context unless it is already maximal.
  for (size_t order = 1; order < kHwcmOrder; ++order) {
    const string& ancestor_chain = history[order - 1];
    if (ancestor_chain.empty()) {
      continue;
    }
    const string chain = ancestor_chain + " " + head;
    hwc[order][chain]++;
    if (order + 1 < kHwcmOrder) {
      child_history[order] = chain;
    }
  }

  for (vector<TreePointer>::const_iterator child = children.begin(); child != children.end(); ++child) {
    extractHeadWordChain(*child, child_history, hwc);
  }
}
|
|
|
|
|
2015-01-14 14:07:42 +03:00
|
|
|
// Returns the head word of a constituent.
//
// Scans the direct children in order and picks the first child that has exactly
// one child, that child being a terminal. That terminal's label is returned.
// This relies on the dependency-tree encoding, where each constituent carries
// a preterminal whose terminal is the head. If several preterminals qualify,
// the first one wins. If none qualifies, an empty string is returned.
string HwcmScorer::getHead(TreePointer tree)
{
  const vector<TreePointer>& children = tree->GetChildren();
  for (size_t idx = 0; idx < children.size(); ++idx) {
    const TreePointer& candidate = children[idx];
    if (candidate->GetLength() != 1) {
      continue;
    }
    const TreePointer& leaf = candidate->GetChildren()[0];
    if (leaf->IsTerminal()) {
      return leaf->GetLabel();
    }
  }
  return "";
}
|
|
|
|
|
|
|
|
void HwcmScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
|
|
|
|
{
|
|
|
|
if (sid >= m_ref_trees.size()) {
|
|
|
|
stringstream msg;
|
|
|
|
msg << "Sentence id (" << sid << ") not found in reference set";
|
|
|
|
throw runtime_error(msg.str());
|
|
|
|
}
|
|
|
|
|
|
|
|
string sentence = this->preprocessSentence(text);
|
|
|
|
|
|
|
|
// if sentence has '|||', assume that tree is in second position (n-best-list);
|
|
|
|
// otherwise, assume it is in first position (calling 'evaluate' with tree as reference)
|
|
|
|
util::TokenIter<util::MultiCharacter> it(sentence, util::MultiCharacter("|||"));
|
|
|
|
++it;
|
|
|
|
if (it) {
|
|
|
|
sentence = it->as_string();
|
|
|
|
}
|
|
|
|
|
|
|
|
TreePointer tree (boost::make_shared<InternalTree>(sentence));
|
|
|
|
vector<map<string, int> > hwc_test (kHwcmOrder);
|
|
|
|
vector<string> history(kHwcmOrder);
|
|
|
|
extractHeadWordChain(tree, history, hwc_test);
|
|
|
|
|
|
|
|
ostringstream stats;
|
|
|
|
for (size_t i = 0; i < kHwcmOrder; i++) {
|
|
|
|
int correct = 0;
|
|
|
|
int test_total = 0;
|
|
|
|
for (map<string, int>::const_iterator it = hwc_test[i].begin(); it != hwc_test[i].end(); it++) {
|
|
|
|
test_total += it->second;
|
|
|
|
map<string, int>::const_iterator it2 = m_ref_hwc[sid][i].find(it->first);
|
|
|
|
if (it2 != m_ref_hwc[sid][i].end()) {
|
|
|
|
correct += std::min(it->second, it2->second);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
stats << correct << " " << test_total << " " << m_ref_lengths[sid][i] << " " ;
|
|
|
|
}
|
|
|
|
|
|
|
|
string stats_str = stats.str();
|
|
|
|
entry.set(stats_str);
|
|
|
|
}
|
|
|
|
|
2014-09-16 19:36:45 +04:00
|
|
|
float HwcmScorer::calculateScore(const vector<ScoreStatsType>& comps) const
|
2014-09-16 14:12:14 +04:00
|
|
|
{
|
|
|
|
float precision = 0;
|
|
|
|
float recall = 0;
|
|
|
|
for (size_t i = 0; i < kHwcmOrder; i++) {
|
|
|
|
float matches = comps[i*3];
|
|
|
|
float test_total = comps[1+(i*3)];
|
|
|
|
float ref_total = comps[2+(i*3)];
|
|
|
|
if (test_total > 0) {
|
|
|
|
precision += matches/test_total;
|
|
|
|
}
|
|
|
|
if (ref_total > 0) {
|
|
|
|
recall += matches/ref_total;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
precision /= (float)kHwcmOrder;
|
|
|
|
recall /= (float)kHwcmOrder;
|
|
|
|
return (2*precision*recall)/(precision+recall); // f1-score
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|