Ongoing moses/phrase-extract refactoring

This commit is contained in:
Phil Williams 2015-06-03 11:10:45 +01:00
parent 5e09d3dc71
commit ed321791a7
13 changed files with 62 additions and 75 deletions

View File

@ -20,37 +20,23 @@
#pragma once
#include <map>
#include <sstream>
#include <string>
#include <vector>
namespace MosesTraining
{
namespace MosesTraining {
class SyntaxNode
{
protected:
int m_start, m_end;
std::string m_label;
public:
struct SyntaxNode {
typedef std::map<std::string, std::string> AttributeMap;
AttributeMap attributes;
SyntaxNode(const std::string &label_, int start_, int end_)
: label(label_)
, start(start_)
, end(end_) {
}
SyntaxNode( int startPos, int endPos, std::string label )
:m_start(startPos)
,m_end(endPos)
,m_label(label) {
}
int GetStart() const {
return m_start;
}
int GetEnd() const {
return m_end;
}
std::string GetLabel() const {
return m_label;
}
std::string label;
int start;
int end;
AttributeMap attributes;
};
} // namespace MosesTraining

View File

@ -44,7 +44,7 @@ void SyntaxNodeCollection::Clear()
SyntaxNode *SyntaxNodeCollection::AddNode(int startPos, int endPos,
const std::string &label)
{
SyntaxNode* newNode = new SyntaxNode( startPos, endPos, label );
SyntaxNode* newNode = new SyntaxNode(label, startPos, endPos);
m_nodes.push_back( newNode );
m_index[ startPos ][ endPos ].push_back( newNode );
m_size = std::max(endPos+1, m_size);
@ -141,16 +141,16 @@ std::auto_ptr<SyntaxTree> SyntaxNodeCollection::ExtractTree()
// node is the root.
root = tree;
tree->parent() = 0;
} else if (prevNode->GetStart() == node->GetStart()) {
} else if (prevNode->start == node->start) {
// prevNode is the parent of node.
assert(prevNode->GetEnd() >= node->GetEnd());
assert(prevNode->end >= node->end);
tree->parent() = prevTree;
prevTree->children().push_back(tree);
} else {
// prevNode is a descendant of node's parent. The lowest common
// ancestor of prevNode and node will be node's parent.
SyntaxTree *ancestor = prevTree->parent();
while (ancestor->value().GetEnd() < tree->value().GetEnd()) {
while (ancestor->value().end < tree->value().end) {
ancestor = ancestor->parent();
}
assert(ancestor);

View File

@ -419,7 +419,7 @@ bool ProcessAndStripXMLTags(string &line, SyntaxNodeCollection &nodeCollection,
const vector< SyntaxNode* >& topNodes = nodeCollection.GetNodes( 0, wordPos-1 );
for( vector< SyntaxNode* >::const_iterator node = topNodes.begin(); node != topNodes.end(); node++ ) {
SyntaxNode *n = *node;
const string &label = n->GetLabel();
const string &label = n->label;
if (topLabelCollection.find( label ) == topLabelCollection.end())
topLabelCollection[ label ] = 0;
topLabelCollection[ label ]++;

View File

@ -21,6 +21,7 @@
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <memory>
#include <stack>
@ -213,7 +214,7 @@ Node *AlignmentGraph::CopyParseTree(const SyntaxTree *root)
{
NodeType nodeType = (root->IsLeaf()) ? TARGET : TREE;
std::auto_ptr<Node> n(new Node(root->value().GetLabel(), nodeType));
std::auto_ptr<Node> n(new Node(root->value().label, nodeType));
if (nodeType == TREE) {
float score = 0.0f;

View File

@ -813,7 +813,7 @@ void ExtractGHKM::CollectWordLabelCounts(
for (SyntaxTree::ConstLeafIterator p(root);
p != SyntaxTree::ConstLeafIterator(); ++p) {
const SyntaxTree &leaf = *p;
const std::string &word = leaf.value().GetLabel();
const std::string &word = leaf.value().label;
const SyntaxTree *ancestor = leaf.parent();
// If unary rule elimination is enabled and this word is at the end of a
// chain of unary rewrites, e.g.
@ -825,7 +825,7 @@ void ExtractGHKM::CollectWordLabelCounts(
ancestor->parent()->children().size() == 1) {
ancestor = ancestor->parent();
}
const std::string &label = ancestor->value().GetLabel();
const std::string &label = ancestor->value().label;
++wordCount[word];
wordLabel[word] = label;
}
@ -837,7 +837,7 @@ std::vector<std::string> ExtractGHKM::ReadTokens(const SyntaxTree &root) const
for (SyntaxTree::ConstLeafIterator p(root);
p != SyntaxTree::ConstLeafIterator(); ++p) {
const SyntaxTree &leaf = *p;
const std::string &word = leaf.value().GetLabel();
const std::string &word = leaf.value().label;
tokens.push_back(word);
}
return tokens;

View File

@ -144,7 +144,7 @@ void ScfgRule::PushSourceLabel(const SyntaxNodeCollection *sourceNodeCollection,
sourceNodeCollection->GetNodes(span.first,span.second);
if (!sourceLabels.empty()) {
// store the topmost matching label from the source syntax tree
m_sourceLabels.push_back(sourceLabels.back()->GetLabel());
m_sourceLabels.push_back(sourceLabels.back()->label);
}
} else {
// no matching source-side syntactic constituent: store nonMatchingLabel

View File

@ -507,7 +507,7 @@ void ExtractTask::preprocessSourceHieroPhrase( int startT, int endT, int startS,
int labelI = labelIndex[ 2+holeCount+holeTotal ];
string label = m_options.sourceSyntax ?
m_sentence.sourceTree.GetNodes(currPos,hole.GetEnd(0))[ labelI ]->GetLabel() : "X";
m_sentence.sourceTree.GetNodes(currPos,hole.GetEnd(0))[ labelI ]->label : "X";
hole.SetLabel(label, 0);
currPos = hole.GetEnd(0);
@ -550,7 +550,7 @@ string ExtractTask::saveTargetHieroPhrase( int startT, int endT, int startS, int
int labelI = labelIndex[ 2+holeCount ];
string targetLabel;
if (m_options.targetSyntax) {
targetLabel = m_sentence.targetTree.GetNodes(currPos,hole.GetEnd(1))[labelI]->GetLabel();
targetLabel = m_sentence.targetTree.GetNodes(currPos,hole.GetEnd(1))[labelI]->label;
} else if (m_options.boundaryRules && (startS == 0 || endS == countS - 1)) {
targetLabel = "S";
} else {
@ -675,7 +675,7 @@ void ExtractTask::saveHieroPhrase( int startT, int endT, int startS, int endS
// phrase labels
string targetLabel;
if (m_options.targetSyntax) {
targetLabel = m_sentence.targetTree.GetNodes(startT,endT)[labelIndex[0] ]->GetLabel();
targetLabel = m_sentence.targetTree.GetNodes(startT,endT)[labelIndex[0] ]->label;
} else if (m_options.boundaryRules && (startS == 0 || endS == countS - 1)) {
targetLabel = "S";
} else {
@ -683,7 +683,7 @@ void ExtractTask::saveHieroPhrase( int startT, int endT, int startS, int endS
}
string sourceLabel = m_options.sourceSyntax ?
m_sentence.sourceTree.GetNodes(startS,endS)[ labelIndex[1] ]->GetLabel() : "X";
m_sentence.sourceTree.GetNodes(startS,endS)[ labelIndex[1] ]->label : "X";
// create non-terms on the source side
preprocessSourceHieroPhrase(startT, endT, startS, endS, indexS, holeColl, labelIndex);
@ -947,13 +947,13 @@ void ExtractTask::addRule( int startT, int endT, int startS, int endS, int count
// phrase labels
string targetLabel,sourceLabel;
if (m_options.targetSyntax && m_options.conditionOnTargetLhs) {
sourceLabel = targetLabel = m_sentence.targetTree.GetNodes(startT,endT)[0]->GetLabel();
sourceLabel = targetLabel = m_sentence.targetTree.GetNodes(startT,endT)[0]->label;
} else {
sourceLabel = m_options.sourceSyntax ?
m_sentence.sourceTree.GetNodes(startS,endS)[0]->GetLabel() : "X";
m_sentence.sourceTree.GetNodes(startS,endS)[0]->label : "X";
if (m_options.targetSyntax) {
targetLabel = m_sentence.targetTree.GetNodes(startT,endT)[0]->GetLabel();
targetLabel = m_sentence.targetTree.GetNodes(startT,endT)[0]->label;
} else if (m_options.boundaryRules && (startS == 0 || endS == countS - 1)) {
targetLabel = "S";
} else {
@ -1166,7 +1166,7 @@ void collectWordLabelCounts( SentenceAlignmentWithSyntax &sentence )
const vector< SyntaxNode* >& labels = sentence.targetTree.GetNodes(ti,ti);
if (labels.size() > 0) {
wordCount[ word ]++;
wordLabel[ word ] = labels[0]->GetLabel();
wordLabel[ word ] = labels[0]->label;
}
}
}

View File

@ -27,7 +27,7 @@ TreeTsgFilter::TreeTsgFilter(
TreeTsgFilter::IdTree *TreeTsgFilter::SyntaxTreeToIdTree(const SyntaxTree &s)
{
IdTree *t = new IdTree(m_testVocab.Insert(s.value().GetLabel()));
IdTree *t = new IdTree(m_testVocab.Insert(s.value().label));
const std::vector<SyntaxTree*> &sChildren = s.children();
std::vector<IdTree*> &tChildren = t->children();
tChildren.reserve(sChildren.size());

View File

@ -37,7 +37,7 @@ void RuleExtractor::Extract(const SyntaxTree &tree, RuleCollection &rc) const
return;
}
std::size_t lhs = non_term_vocab_.Insert(tree.value().GetLabel());
std::size_t lhs = non_term_vocab_.Insert(tree.value().label);
std::vector<std::size_t> rhs;
const std::vector<SyntaxTree *> &children = tree.children();
@ -45,7 +45,7 @@ void RuleExtractor::Extract(const SyntaxTree &tree, RuleCollection &rc) const
for (std::vector<SyntaxTree *>::const_iterator p(children.begin());
p != children.end(); ++p) {
const SyntaxTree &child = **p;
rhs.push_back(non_term_vocab_.Insert(child.value().GetLabel()));
rhs.push_back(non_term_vocab_.Insert(child.value().label));
Extract(child, rc);
}
rc.Add(lhs, rhs);

View File

@ -58,13 +58,13 @@ bool TreeScorer::CalcScores(SyntaxTree &root)
std::vector<std::size_t> key;
key.reserve(children.size()+1);
key.push_back(non_term_vocab_.Lookup(root.value().GetLabel()));
key.push_back(non_term_vocab_.Lookup(root.value().label));
for (std::vector<SyntaxTree *>::const_iterator p(children.begin());
p != children.end(); ++p) {
SyntaxTree *child = *p;
assert(!child->IsLeaf());
key.push_back(non_term_vocab_.Lookup(child->value().GetLabel()));
key.push_back(non_term_vocab_.Lookup(child->value().label));
if (!CalcScores(*child)) {
return false;
}

View File

@ -118,9 +118,9 @@ void store( SyntaxNodeCollection &tree, const vector< string > &words )
// output tree nodes
vector< SyntaxNode* > nodes = tree.GetAllNodes();
for( size_t i=0; i<nodes.size(); i++ ) {
cout << " <tree span=\"" << nodes[i]->GetStart()
<< "-" << nodes[i]->GetEnd()
<< "\" label=\"" << nodes[i]->GetLabel()
cout << " <tree span=\"" << nodes[i]->start
<< "-" << nodes[i]->end
<< "\" label=\"" << nodes[i]->label
<< "\"/>";
}
cout << endl;
@ -133,7 +133,7 @@ void LeftBinarize( SyntaxNodeCollection &tree, ParentNodes &parents )
if (point.size() > 3) {
const vector< SyntaxNode* >& topNodes
= tree.GetNodes( point[0], point[point.size()-1]-1);
string topLabel = topNodes[0]->GetLabel();
string topLabel = topNodes[0]->label;
for(size_t i=2; i<point.size()-1; i++) {
// cerr << "LeftBin " << point[0] << "-" << (point[point.size()-1]-1) << ": " << point[0] << "-" << point[i]-1 << " ^" << topLabel << endl;
@ -151,7 +151,7 @@ void RightBinarize( SyntaxNodeCollection &tree, ParentNodes &parents )
int endPoint = point[point.size()-1]-1;
const vector< SyntaxNode* >& topNodes
= tree.GetNodes( point[0], endPoint);
string topLabel = topNodes[0]->GetLabel();
string topLabel = topNodes[0]->label;
for(size_t i=1; i<point.size()-2; i++) {
// cerr << "RightBin " << point[0] << "-" << (point[point.size()-1]-1) << ": " << point[i] << "-" << endPoint << " ^" << topLabel << endl;
@ -178,29 +178,29 @@ void SAMT( SyntaxNodeCollection &tree, ParentNodes &parents )
// cerr << endl;
for(size_t i = 0; i+2 < point.size(); i++) {
// cerr << "\tadding " << point[i] << ";" << point[i+1] << ";" << (point[i+2]-1) << ": " << tree.GetNodes(point[i ],point[i+1]-1)[0]->GetLabel() << "+" << tree.GetNodes(point[i+1],point[i+2]-1)[0]->GetLabel() << endl;
// cerr << "\tadding " << point[i] << ";" << point[i+1] << ";" << (point[i+2]-1) << ": " << tree.GetNodes(point[i ],point[i+1]-1)[0]->label << "+" << tree.GetNodes(point[i+1],point[i+2]-1)[0]->label << endl;
newTree.AddNode( point[i],point[i+2]-1,
tree.GetNodes(point[i ],point[i+1]-1)[0]->GetLabel()
tree.GetNodes(point[i ],point[i+1]-1)[0]->label
+ "+" +
tree.GetNodes(point[i+1],point[i+2]-1)[0]->GetLabel() );
tree.GetNodes(point[i+1],point[i+2]-1)[0]->label);
}
}
if (point.size() >= 4) {
int ps = point.size();
string topLabel = tree.GetNodes(point[0],point[ps-1]-1)[0]->GetLabel();
string topLabel = tree.GetNodes(point[0],point[ps-1]-1)[0]->label;
// cerr << "\tadding " << topLabel + "\\" + tree.GetNodes(point[0],point[1]-1)[0]->GetLabel() << endl;
// cerr << "\tadding " << topLabel + "\\" + tree.GetNodes(point[0],point[1]-1)[0]->label << endl;
newTree.AddNode( point[1],point[ps-1]-1,
topLabel
+ "\\" +
tree.GetNodes(point[0],point[1]-1)[0]->GetLabel() );
tree.GetNodes(point[0],point[1]-1)[0]->label );
// cerr << "\tadding " << topLabel + "/" + tree.GetNodes(point[ps-2],point[ps-1]-1)[0]->GetLabel() << endl;
// cerr << "\tadding " << topLabel + "/" + tree.GetNodes(point[ps-2],point[ps-1]-1)[0]->label << endl;
newTree.AddNode( point[0],point[ps-2]-1,
topLabel
+ "/" +
tree.GetNodes(point[ps-2],point[ps-1]-1)[0]->GetLabel() );
tree.GetNodes(point[ps-2],point[ps-1]-1)[0]->label );
}
}
@ -219,12 +219,12 @@ void SAMT( SyntaxNodeCollection &tree, ParentNodes &parents )
for(int mid=start+1; mid<=end && !done; mid++) {
if (tree.HasNode(start,mid-1) && tree.HasNode(mid,end)) {
// cerr << "\tadding " << tree.GetNodes(start,mid-1)[0]->GetLabel() << "++" << tree.GetNodes(mid, end )[0]->GetLabel() << endl;
// cerr << "\tadding " << tree.GetNodes(start,mid-1)[0]->label << "++" << tree.GetNodes(mid, end )[0]->label << endl;
newTree.AddNode( start, end,
tree.GetNodes(start,mid-1)[0]->GetLabel()
tree.GetNodes(start,mid-1)[0]->label
+ "++" +
tree.GetNodes(mid, end )[0]->GetLabel() );
tree.GetNodes(mid, end )[0]->label );
done = true;
}
}
@ -234,9 +234,9 @@ void SAMT( SyntaxNodeCollection &tree, ParentNodes &parents )
for(int postEnd=end+1; postEnd<numWords && !done; postEnd++) {
if (tree.HasNode(start,postEnd) && tree.HasNode(end+1,postEnd)) {
newTree.AddNode( start, end,
tree.GetNodes(start,postEnd)[0]->GetLabel()
tree.GetNodes(start,postEnd)[0]->label
+ "//" +
tree.GetNodes(end+1,postEnd)[0]->GetLabel() );
tree.GetNodes(end+1,postEnd)[0]->label );
done = true;
}
}
@ -245,11 +245,11 @@ void SAMT( SyntaxNodeCollection &tree, ParentNodes &parents )
// if matching a constituent A left-minus constituent B: use A\\B
for(int preStart=start-1; preStart>=0; preStart--) {
if (tree.HasNode(preStart,end) && tree.HasNode(preStart,start-1)) {
// cerr << "\tadding " << tree.GetNodes(preStart,end )[0]->GetLabel() << "\\\\" <<tree.GetNodes(preStart,start-1)[0]->GetLabel() << endl;
// cerr << "\tadding " << tree.GetNodes(preStart,end )[0]->label << "\\\\" <<tree.GetNodes(preStart,start-1)[0]->label << endl;
newTree.AddNode( start, end,
tree.GetNodes(preStart,end )[0]->GetLabel()
tree.GetNodes(preStart,end )[0]->label
+ "\\\\" +
tree.GetNodes(preStart,start-1)[0]->GetLabel() );
tree.GetNodes(preStart,start-1)[0]->label );
done = true;
}
}
@ -268,6 +268,6 @@ void SAMT( SyntaxNodeCollection &tree, ParentNodes &parents )
// adding all new nodes
vector< SyntaxNode* > nodes = newTree.GetAllNodes();
for( size_t i=0; i<nodes.size(); i++ ) {
tree.AddNode( nodes[i]->GetStart(), nodes[i]->GetEnd(), nodes[i]->GetLabel());
tree.AddNode( nodes[i]->start, nodes[i]->end, nodes[i]->label);
}
}

View File

@ -47,15 +47,15 @@ void XmlTreeParser::AttachWords(const std::vector<std::string> &words,
for (std::vector<SyntaxTree*>::iterator p = leaves.begin(); p != leaves.end();
++p) {
SyntaxTree *leaf = *p;
const int start = leaf->value().GetStart();
const int end = leaf->value().GetEnd();
const int start = leaf->value().start;
const int end = leaf->value().end;
if (start != end) {
std::ostringstream msg;
msg << "leaf node covers multiple words (" << start << "-" << end
<< "): this is currently unsupported";
throw Exception(msg.str());
}
SyntaxTree *newLeaf = new SyntaxTree(SyntaxNode(start, end, *q++));
SyntaxTree *newLeaf = new SyntaxTree(SyntaxNode(*q++, start, end));
leaf->children().push_back(newLeaf);
newLeaf->parent() = leaf;
}

View File

@ -16,7 +16,7 @@ void XmlTreeWriter::Write(const SyntaxTree &tree) const {
assert(!tree.IsLeaf());
// Opening tag
out_ << "<tree label=\"" << Escape(tree.value().GetLabel()) << "\"";
out_ << "<tree label=\"" << Escape(tree.value().label) << "\"";
for (SyntaxNode::AttributeMap::const_iterator
p = tree.value().attributes.begin();
p != tree.value().attributes.end(); ++p) {
@ -31,7 +31,7 @@ void XmlTreeWriter::Write(const SyntaxTree &tree) const {
p != tree.children().end(); ++p) {
SyntaxTree &child = **p;
if (child.IsLeaf()) {
out_ << " " << Escape(child.value().GetLabel());
out_ << " " << Escape(child.value().label);
} else {
out_ << " ";
Write(child);