Ongoing moses/phrase-extract refactoring

This commit is contained in:
Phil Williams 2015-06-03 10:33:46 +01:00
parent 2e21f051f2
commit 5e09d3dc71
5 changed files with 17 additions and 85 deletions

View File

@ -32,9 +32,6 @@ class SyntaxNode
protected:
int m_start, m_end;
std::string m_label;
std::vector< SyntaxNode* > m_children;
SyntaxNode* m_parent;
float m_pcfgScore;
public:
typedef std::map<std::string, std::string> AttributeMap;
@ -43,9 +40,7 @@ public:
SyntaxNode( int startPos, int endPos, std::string label )
:m_start(startPos)
,m_end(endPos)
,m_label(label)
,m_parent(0)
,m_pcfgScore(0.0f) {
,m_label(label) {
}
int GetStart() const {
return m_start;
@ -56,24 +51,6 @@ public:
std::string GetLabel() const {
return m_label;
}
float GetPcfgScore() const {
return m_pcfgScore;
}
void SetPcfgScore(float score) {
m_pcfgScore = score;
}
SyntaxNode *GetParent() {
return m_parent;
}
void SetParent(SyntaxNode *parent) {
m_parent = parent;
}
void AddChild(SyntaxNode* child) {
m_children.push_back(child);
}
const std::vector< SyntaxNode* > &GetChildren() const {
return m_children;
}
};
} // namespace MosesTraining

View File

@ -33,7 +33,6 @@ SyntaxNodeCollection::~SyntaxNodeCollection()
void SyntaxNodeCollection::Clear()
{
m_top = 0;
// loop through all m_nodes, delete them
for(size_t i=0; i<m_nodes.size(); i++) {
delete m_nodes[i];
@ -110,48 +109,6 @@ const std::vector< SyntaxNode* >& SyntaxNodeCollection::GetNodes( int startPos,
return endIndex->second;
}
void SyntaxNodeCollection::ConnectNodes()
{
typedef SyntaxTreeIndex2::const_reverse_iterator InnerIterator;
SyntaxNode *prev = 0;
// Iterate over all start indices from lowest to highest.
for (SyntaxTreeIndexIterator p = m_index.begin(); p != m_index.end(); ++p) {
const SyntaxTreeIndex2 &inner = p->second;
// Iterate over all end indices from highest to lowest.
for (InnerIterator q = inner.rbegin(); q != inner.rend(); ++q) {
const std::vector<SyntaxNode*> &nodes = q->second;
// Iterate over all nodes that cover the same span in order of tree
// depth, top-most first.
for (std::vector<SyntaxNode*>::const_reverse_iterator r = nodes.rbegin();
r != nodes.rend(); ++r) {
SyntaxNode *node = *r;
if (!prev) {
// node is the root.
m_top = node;
node->SetParent(0);
} else if (prev->GetStart() == node->GetStart()) {
// prev is the parent of node.
assert(prev->GetEnd() >= node->GetEnd());
node->SetParent(prev);
prev->AddChild(node);
} else {
// prev is a descendant of node's parent. The lowest common
// ancestor of prev and node will be node's parent.
SyntaxNode *ancestor = prev->GetParent();
while (ancestor->GetEnd() < node->GetEnd()) {
ancestor = ancestor->GetParent();
}
assert(ancestor);
node->SetParent(ancestor);
ancestor->AddChild(node);
}
prev = node;
}
}
}
}
std::auto_ptr<SyntaxTree> SyntaxNodeCollection::ExtractTree()
{
std::map<SyntaxNode *, SyntaxTree *> nodeToTree;

View File

@ -38,7 +38,6 @@ class SyntaxNodeCollection
{
protected:
std::vector< SyntaxNode* > m_nodes;
SyntaxNode* m_top;
typedef std::map< int, std::vector< SyntaxNode* > > SyntaxTreeIndex2;
typedef SyntaxTreeIndex2::const_iterator SyntaxTreeIndexIterator2;
@ -49,18 +48,12 @@ protected:
std::vector< SyntaxNode* > m_emptyNode;
public:
SyntaxNodeCollection()
: m_top(0) // m_top doesn't get set unless ConnectNodes is called.
, m_size(0) {}
SyntaxNodeCollection() : m_size(0) {}
~SyntaxNodeCollection();
SyntaxNode *AddNode( int startPos, int endPos, const std::string &label );
SyntaxNode *GetTop() {
return m_top;
}
ParentNodes Parse();
bool HasNode( int startPos, int endPos ) const;
const std::vector< SyntaxNode* >& GetNodes( int startPos, int endPos ) const;
@ -70,7 +63,6 @@ public:
size_t GetNumWords() const {
return m_size;
}
void ConnectNodes();
void Clear();
std::auto_ptr<SyntaxTree> ExtractTree();

View File

@ -398,10 +398,6 @@ bool ProcessAndStripXMLTags(string &line, SyntaxNodeCollection &nodeCollection,
string label = ParseXmlTagAttribute(tagContent,"label");
labelCollection.insert( label );
string pcfgString = ParseXmlTagAttribute(tagContent,"pcfg");
float pcfgScore = pcfgString == "" ? 0.0f
: std::atof(pcfgString.c_str());
// report what we have processed so far
if (0) {
cerr << "XML TAG NAME IS: '" << tagName << "'" << endl;
@ -409,7 +405,6 @@ bool ProcessAndStripXMLTags(string &line, SyntaxNodeCollection &nodeCollection,
cerr << "XML SPAN IS: " << startPos << "-" << (endPos-1) << endl;
}
SyntaxNode *node = nodeCollection.AddNode( startPos, endPos-1, label );
node->SetPcfgScore(pcfgScore);
ParseXmlTagAttributes(tagContent, node->attributes);
}
}

View File

@ -110,6 +110,8 @@ void collectWordLabelCounts(SentenceAlignmentWithSyntax &sentence );
void writeGlueGrammar(const string &, RuleExtractionOptions &options, set< string > &targetLabelCollection, map< string, int > &targetTopLabelCollection);
void writeUnknownWordLabel(const string &);
double getPcfgScore(const SyntaxNode &);
int main(int argc, char* argv[])
{
@ -564,8 +566,7 @@ string ExtractTask::saveTargetHieroPhrase( int startT, int endT, int startS, int
}
if (m_options.pcfgScore) {
double score = m_sentence.targetTree.GetNodes(currPos,hole.GetEnd(1))[labelI]->GetPcfgScore();
logPCFGScore -= score;
logPCFGScore -= getPcfgScore(*m_sentence.targetTree.GetNodes(currPos,hole.GetEnd(1))[labelI]);
}
currPos = hole.GetEnd(1);
@ -689,7 +690,7 @@ void ExtractTask::saveHieroPhrase( int startT, int endT, int startS, int endS
// target
if (m_options.pcfgScore) {
double logPCFGScore = m_sentence.targetTree.GetNodes(startT,endT)[labelIndex[0]]->GetPcfgScore();
double logPCFGScore = getPcfgScore(*m_sentence.targetTree.GetNodes(startT,endT)[labelIndex[0]]);
rule.target = saveTargetHieroPhrase(startT, endT, startS, endS, indexT, holeColl, labelIndex, logPCFGScore, countS)
+ " [" + targetLabel + "]";
rule.pcfgScore = std::exp(logPCFGScore);
@ -973,7 +974,7 @@ void ExtractTask::addRule( int startT, int endT, int startS, int endS, int count
rule.target += "[" + targetLabel + "]";
if (m_options.pcfgScore) {
double logPCFGScore = m_sentence.targetTree.GetNodes(startT,endT)[0]->GetPcfgScore();
double logPCFGScore = getPcfgScore(*m_sentence.targetTree.GetNodes(startT,endT)[0]);
rule.pcfgScore = std::exp(logPCFGScore);
}
@ -1194,3 +1195,13 @@ void writeUnknownWordLabel(const string & fileName)
outFile.close();
}
double getPcfgScore(const SyntaxNode &node)
{
double score = 0.0f;
SyntaxNode::AttributeMap::const_iterator p = node.attributes.find("pcfg");
if (p != node.attributes.end()) {
score = std::atof(p->second.c_str());
}
return score;
}