mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-26 05:14:36 +03:00
Ongoing moses/phrase-extract refactoring
This commit is contained in:
parent
2e21f051f2
commit
5e09d3dc71
@ -32,9 +32,6 @@ class SyntaxNode
|
||||
protected:
|
||||
int m_start, m_end;
|
||||
std::string m_label;
|
||||
std::vector< SyntaxNode* > m_children;
|
||||
SyntaxNode* m_parent;
|
||||
float m_pcfgScore;
|
||||
public:
|
||||
typedef std::map<std::string, std::string> AttributeMap;
|
||||
|
||||
@ -43,9 +40,7 @@ public:
|
||||
SyntaxNode( int startPos, int endPos, std::string label )
|
||||
:m_start(startPos)
|
||||
,m_end(endPos)
|
||||
,m_label(label)
|
||||
,m_parent(0)
|
||||
,m_pcfgScore(0.0f) {
|
||||
,m_label(label) {
|
||||
}
|
||||
int GetStart() const {
|
||||
return m_start;
|
||||
@ -56,24 +51,6 @@ public:
|
||||
std::string GetLabel() const {
|
||||
return m_label;
|
||||
}
|
||||
float GetPcfgScore() const {
|
||||
return m_pcfgScore;
|
||||
}
|
||||
void SetPcfgScore(float score) {
|
||||
m_pcfgScore = score;
|
||||
}
|
||||
SyntaxNode *GetParent() {
|
||||
return m_parent;
|
||||
}
|
||||
void SetParent(SyntaxNode *parent) {
|
||||
m_parent = parent;
|
||||
}
|
||||
void AddChild(SyntaxNode* child) {
|
||||
m_children.push_back(child);
|
||||
}
|
||||
const std::vector< SyntaxNode* > &GetChildren() const {
|
||||
return m_children;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace MosesTraining
|
||||
|
@ -33,7 +33,6 @@ SyntaxNodeCollection::~SyntaxNodeCollection()
|
||||
|
||||
void SyntaxNodeCollection::Clear()
|
||||
{
|
||||
m_top = 0;
|
||||
// loop through all m_nodes, delete them
|
||||
for(size_t i=0; i<m_nodes.size(); i++) {
|
||||
delete m_nodes[i];
|
||||
@ -110,48 +109,6 @@ const std::vector< SyntaxNode* >& SyntaxNodeCollection::GetNodes( int startPos,
|
||||
return endIndex->second;
|
||||
}
|
||||
|
||||
void SyntaxNodeCollection::ConnectNodes()
|
||||
{
|
||||
typedef SyntaxTreeIndex2::const_reverse_iterator InnerIterator;
|
||||
|
||||
SyntaxNode *prev = 0;
|
||||
// Iterate over all start indices from lowest to highest.
|
||||
for (SyntaxTreeIndexIterator p = m_index.begin(); p != m_index.end(); ++p) {
|
||||
const SyntaxTreeIndex2 &inner = p->second;
|
||||
// Iterate over all end indices from highest to lowest.
|
||||
for (InnerIterator q = inner.rbegin(); q != inner.rend(); ++q) {
|
||||
const std::vector<SyntaxNode*> &nodes = q->second;
|
||||
// Iterate over all nodes that cover the same span in order of tree
|
||||
// depth, top-most first.
|
||||
for (std::vector<SyntaxNode*>::const_reverse_iterator r = nodes.rbegin();
|
||||
r != nodes.rend(); ++r) {
|
||||
SyntaxNode *node = *r;
|
||||
if (!prev) {
|
||||
// node is the root.
|
||||
m_top = node;
|
||||
node->SetParent(0);
|
||||
} else if (prev->GetStart() == node->GetStart()) {
|
||||
// prev is the parent of node.
|
||||
assert(prev->GetEnd() >= node->GetEnd());
|
||||
node->SetParent(prev);
|
||||
prev->AddChild(node);
|
||||
} else {
|
||||
// prev is a descendant of node's parent. The lowest common
|
||||
// ancestor of prev and node will be node's parent.
|
||||
SyntaxNode *ancestor = prev->GetParent();
|
||||
while (ancestor->GetEnd() < node->GetEnd()) {
|
||||
ancestor = ancestor->GetParent();
|
||||
}
|
||||
assert(ancestor);
|
||||
node->SetParent(ancestor);
|
||||
ancestor->AddChild(node);
|
||||
}
|
||||
prev = node;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::auto_ptr<SyntaxTree> SyntaxNodeCollection::ExtractTree()
|
||||
{
|
||||
std::map<SyntaxNode *, SyntaxTree *> nodeToTree;
|
||||
|
@ -38,7 +38,6 @@ class SyntaxNodeCollection
|
||||
{
|
||||
protected:
|
||||
std::vector< SyntaxNode* > m_nodes;
|
||||
SyntaxNode* m_top;
|
||||
|
||||
typedef std::map< int, std::vector< SyntaxNode* > > SyntaxTreeIndex2;
|
||||
typedef SyntaxTreeIndex2::const_iterator SyntaxTreeIndexIterator2;
|
||||
@ -49,18 +48,12 @@ protected:
|
||||
std::vector< SyntaxNode* > m_emptyNode;
|
||||
|
||||
public:
|
||||
SyntaxNodeCollection()
|
||||
: m_top(0) // m_top doesn't get set unless ConnectNodes is called.
|
||||
, m_size(0) {}
|
||||
SyntaxNodeCollection() : m_size(0) {}
|
||||
|
||||
~SyntaxNodeCollection();
|
||||
|
||||
SyntaxNode *AddNode( int startPos, int endPos, const std::string &label );
|
||||
|
||||
SyntaxNode *GetTop() {
|
||||
return m_top;
|
||||
}
|
||||
|
||||
ParentNodes Parse();
|
||||
bool HasNode( int startPos, int endPos ) const;
|
||||
const std::vector< SyntaxNode* >& GetNodes( int startPos, int endPos ) const;
|
||||
@ -70,7 +63,6 @@ public:
|
||||
size_t GetNumWords() const {
|
||||
return m_size;
|
||||
}
|
||||
void ConnectNodes();
|
||||
void Clear();
|
||||
|
||||
std::auto_ptr<SyntaxTree> ExtractTree();
|
||||
|
@ -398,10 +398,6 @@ bool ProcessAndStripXMLTags(string &line, SyntaxNodeCollection &nodeCollection,
|
||||
string label = ParseXmlTagAttribute(tagContent,"label");
|
||||
labelCollection.insert( label );
|
||||
|
||||
string pcfgString = ParseXmlTagAttribute(tagContent,"pcfg");
|
||||
float pcfgScore = pcfgString == "" ? 0.0f
|
||||
: std::atof(pcfgString.c_str());
|
||||
|
||||
// report what we have processed so far
|
||||
if (0) {
|
||||
cerr << "XML TAG NAME IS: '" << tagName << "'" << endl;
|
||||
@ -409,7 +405,6 @@ bool ProcessAndStripXMLTags(string &line, SyntaxNodeCollection &nodeCollection,
|
||||
cerr << "XML SPAN IS: " << startPos << "-" << (endPos-1) << endl;
|
||||
}
|
||||
SyntaxNode *node = nodeCollection.AddNode( startPos, endPos-1, label );
|
||||
node->SetPcfgScore(pcfgScore);
|
||||
ParseXmlTagAttributes(tagContent, node->attributes);
|
||||
}
|
||||
}
|
||||
|
@ -110,6 +110,8 @@ void collectWordLabelCounts(SentenceAlignmentWithSyntax &sentence );
|
||||
void writeGlueGrammar(const string &, RuleExtractionOptions &options, set< string > &targetLabelCollection, map< string, int > &targetTopLabelCollection);
|
||||
void writeUnknownWordLabel(const string &);
|
||||
|
||||
double getPcfgScore(const SyntaxNode &);
|
||||
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@ -564,8 +566,7 @@ string ExtractTask::saveTargetHieroPhrase( int startT, int endT, int startS, int
|
||||
}
|
||||
|
||||
if (m_options.pcfgScore) {
|
||||
double score = m_sentence.targetTree.GetNodes(currPos,hole.GetEnd(1))[labelI]->GetPcfgScore();
|
||||
logPCFGScore -= score;
|
||||
logPCFGScore -= getPcfgScore(*m_sentence.targetTree.GetNodes(currPos,hole.GetEnd(1))[labelI]);
|
||||
}
|
||||
|
||||
currPos = hole.GetEnd(1);
|
||||
@ -689,7 +690,7 @@ void ExtractTask::saveHieroPhrase( int startT, int endT, int startS, int endS
|
||||
|
||||
// target
|
||||
if (m_options.pcfgScore) {
|
||||
double logPCFGScore = m_sentence.targetTree.GetNodes(startT,endT)[labelIndex[0]]->GetPcfgScore();
|
||||
double logPCFGScore = getPcfgScore(*m_sentence.targetTree.GetNodes(startT,endT)[labelIndex[0]]);
|
||||
rule.target = saveTargetHieroPhrase(startT, endT, startS, endS, indexT, holeColl, labelIndex, logPCFGScore, countS)
|
||||
+ " [" + targetLabel + "]";
|
||||
rule.pcfgScore = std::exp(logPCFGScore);
|
||||
@ -973,7 +974,7 @@ void ExtractTask::addRule( int startT, int endT, int startS, int endS, int count
|
||||
rule.target += "[" + targetLabel + "]";
|
||||
|
||||
if (m_options.pcfgScore) {
|
||||
double logPCFGScore = m_sentence.targetTree.GetNodes(startT,endT)[0]->GetPcfgScore();
|
||||
double logPCFGScore = getPcfgScore(*m_sentence.targetTree.GetNodes(startT,endT)[0]);
|
||||
rule.pcfgScore = std::exp(logPCFGScore);
|
||||
}
|
||||
|
||||
@ -1194,3 +1195,13 @@ void writeUnknownWordLabel(const string & fileName)
|
||||
|
||||
outFile.close();
|
||||
}
|
||||
|
||||
double getPcfgScore(const SyntaxNode &node)
|
||||
{
|
||||
double score = 0.0f;
|
||||
SyntaxNode::AttributeMap::const_iterator p = node.attributes.find("pcfg");
|
||||
if (p != node.attributes.end()) {
|
||||
score = std::atof(p->second.c_str());
|
||||
}
|
||||
return score;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user