multi-threaded hierarchical rule extractor

This commit is contained in:
phikoehn 2012-04-17 05:54:48 +01:00
parent 05f02157ab
commit 2c520fb93c
6 changed files with 238 additions and 147 deletions

View File

@ -31,7 +31,7 @@ namespace Moses
{
ThreadPool::ThreadPool( size_t numThreads )
: m_stopped(false), m_stopping(false)
: m_stopped(false), m_stopping(false), m_queueLimit(0)
{
for (size_t i = 0; i < numThreads; ++i) {
m_threads.create_thread(boost::bind(&ThreadPool::Execute,this));
@ -70,6 +70,9 @@ void ThreadPool::Submit( Task* task )
if (m_stopping) {
throw runtime_error("ThreadPool stopping - unable to accept new jobs");
}
if (m_queueLimit > 0 && m_tasks.size() >= m_queueLimit) {
m_threadAvailable.wait(lock);
}
m_tasks.push(task);
m_threadNeeded.notify_all();
@ -97,7 +100,6 @@ void ThreadPool::Stop(bool processRemainingJobs)
}
m_threadNeeded.notify_all();
m_threads.join_all();
}

View File

@ -81,6 +81,11 @@ class ThreadPool
**/
void Stop(bool processRemainingJobs = false);
/**
* Set maximum number of queued tasks (Submit blocks once the limit is reached; 0 means no limit)
**/
void SetQueueLimit( size_t limit ) { m_queueLimit = limit; }
private:
/**
* The main loop executed by each thread.
@ -94,6 +99,7 @@ private:
boost::condition_variable m_threadAvailable;
bool m_stopped;
bool m_stopping;
size_t m_queueLimit;
};
class TestTask : public Task

View File

@ -11,15 +11,21 @@
using namespace std;
/**
 * Write the non-terminal span lengths of this rule to an arbitrary stream.
 *
 * The text is produced by the std::ostringstream overload of
 * OutputNTLengths and then copied to out, so both overloads emit the same
 * characters.
 *
 * @param out  destination stream
 */
void ExtractedRule::OutputNTLengths(std::ostream &out) const
{
  std::ostringstream outString;
  OutputNTLengths(outString);
  // Insert the buffered text, not the stream object: streaming the
  // ostringstream itself prints a pointer value under C++03 (via
  // operator void*) and is ill-formed from C++11 on.
  out << outString.str();
}
// Appends every entry of m_ntLengths to outString in the form
// "key=first,second " (each entry is followed by one space), iterating the
// map in key order.
void ExtractedRule::OutputNTLengths(std::ostringstream &outString) const
{
std::map<size_t, std::pair<size_t, size_t> >::const_iterator iter;
for (iter = m_ntLengths.begin(); iter != m_ntLengths.end(); ++iter)
{
size_t sourcePos = iter->first;
const std::pair<size_t, size_t> &spanLengths = iter->second;
// NOTE(review): the next two lines look like the pre-change and post-change
// versions of a single statement shown together by the diff view; `out` is
// not declared in this overload, only `outString` is -- confirm.
out << sourcePos << "=" << spanLengths.first << "," <<spanLengths.second << " ";
outString << sourcePos << "=" << spanLengths.first << "," <<spanLengths.second << " ";
}
}
std::ostream& operator<<(std::ostream &out, const ExtractedRule &obj)

View File

@ -23,6 +23,7 @@
#include <string>
#include <iostream>
#include <sstream>
#include <map>
// sentence-level collection of rules
@ -65,6 +66,7 @@ public:
}
void OutputNTLengths(std::ostream &out) const;
void OutputNTLengths(std::ostringstream &out) const;
};
#endif

View File

@ -3,7 +3,7 @@ alias trees : SyntaxTree.cpp XmlTree.cpp : : : <include>. ;
exe extract : tables-core.cpp SentenceAlignment.cpp extract.cpp InputFileStream ;
exe extract-rules : tables-core.cpp SentenceAlignment.cpp SentenceAlignmentWithSyntax.cpp SyntaxTree.cpp XmlTree.cpp HoleCollection.cpp extract-rules.cpp ExtractedRule.cpp InputFileStream ;
exe extract-rules : tables-core.cpp SentenceAlignment.cpp SentenceAlignmentWithSyntax.cpp SyntaxTree.cpp XmlTree.cpp HoleCollection.cpp extract-rules.cpp ExtractedRule.cpp InputFileStream ../../../moses/src//ThreadPool ;
exe extract-lex : extract-lex.cpp InputFileStream ;

View File

@ -45,6 +45,8 @@
#include "tables-core.h"
#include "XmlTree.h"
#include "InputFileStream.h"
#include "../../../moses/src/ThreadPool.h"
#include "../../../moses/src/OutputCollector.h"
#define LINE_MAX_LENGTH 500000
@ -53,20 +55,48 @@ using namespace std;
typedef vector< int > LabelIndex;
typedef map< int, int > WordIndex;
vector< ExtractedRule > extractedRules;
class ExtractTask : public Moses::Task {
private:
size_t m_id;
SentenceAlignmentWithSyntax *m_sentence;
RuleExtractionOptions &m_options;
Moses::OutputCollector* m_extractCollector;
Moses::OutputCollector* m_extractCollectorInv;
void extractRules(SentenceAlignmentWithSyntax & );
public:
ExtractTask(size_t id, SentenceAlignmentWithSyntax *sentence, RuleExtractionOptions &options, Moses::OutputCollector* extractCollector, Moses::OutputCollector* extractCollectorInv):
m_id(id),
m_sentence(sentence),
m_options(options),
m_extractCollector(extractCollector),
m_extractCollectorInv(extractCollectorInv) {}
~ExtractTask() { delete m_sentence; }
void Run();
private:
vector< ExtractedRule > m_extractedRules;
// main functions
void extractRules();
void addRuleToCollection(ExtractedRule &rule);
void consolidateRules();
void writeRulesToFile();
void writeGlueGrammar(const string &);
void collectWordLabelCounts(SentenceAlignmentWithSyntax &sentence );
void writeUnknownWordLabel(const string &);
void addRule( SentenceAlignmentWithSyntax &, int, int, int, int
, RuleExist &ruleExist);
void addHieroRule( SentenceAlignmentWithSyntax &sentence, int startT, int endT, int startS, int endS
// subs
void addRule( int, int, int, int, RuleExist &ruleExist);
void addHieroRule( int startT, int endT, int startS, int endS
, RuleExist &ruleExist, const HoleCollection &holeColl, int numHoles, int initStartF, int wordCountT, int wordCountS);
void printHieroPhrase( int startT, int endT, int startS, int endS
, HoleCollection &holeColl, LabelIndex &labelIndex);
string printTargetHieroPhrase( int startT, int endT, int startS, int endS
, WordIndex &indexT, HoleCollection &holeColl, const LabelIndex &labelIndex);
string printSourceHieroPhrase( int startT, int endT, int startS, int endS
, HoleCollection &holeColl, const LabelIndex &labelIndex);
void preprocessSourceHieroPhrase( int startT, int endT, int startS, int endS
, WordIndex &indexS, HoleCollection &holeColl, const LabelIndex &labelIndex);
void printHieroAlignment( int startT, int endT, int startS, int endS
, const WordIndex &indexS, const WordIndex &indexT, HoleCollection &holeColl, ExtractedRule &rule);
void printAllHieroPhrases( int startT, int endT, int startS, int endS, HoleCollection &holeColl);
inline string IntToString( int i )
{
@ -74,22 +104,29 @@ inline string IntToString( int i )
out << i;
return out.str();
}
};
ofstream extractFile;
ofstream extractFileInv;
set< string > targetLabelCollection, sourceLabelCollection;
map< string, int > targetTopLabelCollection, sourceTopLabelCollection;
// stats for glue grammar and unknown word label probabilities
void collectWordLabelCounts(SentenceAlignmentWithSyntax &sentence );
void writeGlueGrammar(const string &, RuleExtractionOptions &options, set< string > &targetLabelCollection, map< string, int > &targetTopLabelCollection);
void writeUnknownWordLabel(const string &);
RuleExtractionOptions options;
int main(int argc, char* argv[])
{
cerr << "extract-rules, written by Philipp Koehn\n"
<< "rule extraction from an aligned parallel corpus\n";
RuleExtractionOptions options;
#ifdef WITH_THREADS
int thread_count = 1;
#endif
if (argc < 5) {
cerr << "syntax: extract-rules corpus.target corpus.source corpus.align extract "
<< " [ --GlueGrammar FILE"
cerr << "syntax: extract-rules corpus.target corpus.source corpus.align extract ["
#ifdef WITH_THREADS
<< " --threads NUM |"
#endif
<< " --GlueGrammar FILE"
<< " | --UnknownWordLabel FILE"
<< " | --OnlyDirect"
<< " | --OutputNTLengths"
@ -218,6 +255,12 @@ int main(int argc, char* argv[])
options.fractionalCounting = false;
} else if (strcmp(argv[i],"--OutputNTLengths") == 0) {
options.outputNTLengths = true;
#ifdef WITH_THREADS
} else if (strcmp(argv[i],"-threads") == 0 ||
strcmp(argv[i],"--threads") == 0 ||
strcmp(argv[i],"--Threads") == 0) {
thread_count = atoi(argv[++i]);
#endif
} else {
cerr << "extract: syntax error, unknown option '" << string(argv[i]) << "'\n";
exit(1);
@ -237,12 +280,28 @@ int main(int argc, char* argv[])
// open output files
string fileNameExtractInv = fileNameExtract + ".inv";
ofstream extractFile;
ofstream extractFileInv;
extractFile.open(fileNameExtract.c_str());
if (!options.onlyDirectFlag)
extractFileInv.open(fileNameExtractInv.c_str());
// output into file
Moses::OutputCollector* extractCollector = new Moses::OutputCollector(&extractFile);
Moses::OutputCollector* extractCollectorInv = new Moses::OutputCollector(&extractFileInv);
// stats on labels for glue grammar and unknown word label probabilities
set< string > targetLabelCollection, sourceLabelCollection;
map< string, int > targetTopLabelCollection, sourceTopLabelCollection;
#ifdef WITH_THREADS
// set up thread pool
Moses::ThreadPool pool(thread_count);
pool.SetQueueLimit(1000);
#endif
// loop through all sentence pairs
int i=0;
size_t i=0;
while(true) {
i++;
if (i%1000 == 0) cerr << "." << flush;
@ -255,11 +314,10 @@ int main(int argc, char* argv[])
if (tFileP->eof()) break;
SAFE_GETLINE((*sFileP), sourceString, LINE_MAX_LENGTH, '\n', __FILE__);
SAFE_GETLINE((*aFileP), alignmentString, LINE_MAX_LENGTH, '\n', __FILE__);
SentenceAlignmentWithSyntax sentence(targetLabelCollection,
sourceLabelCollection,
targetTopLabelCollection,
sourceTopLabelCollection,
options);
SentenceAlignmentWithSyntax *sentence = new SentenceAlignmentWithSyntax
(targetLabelCollection, sourceLabelCollection,
targetTopLabelCollection, sourceTopLabelCollection, options);
//az: output src, tgt, and alignment line
if (options.onlyOutputSpanInfo) {
cout << "LOG: SRC: " << sourceString << endl;
@ -268,18 +326,32 @@ int main(int argc, char* argv[])
cout << "LOG: PHRASES_BEGIN:" << endl;
}
if (sentence.create(targetString, sourceString, alignmentString, i)) {
if (sentence->create(targetString, sourceString, alignmentString, i)) {
if (options.unknownWordLabelFlag) {
collectWordLabelCounts(sentence);
collectWordLabelCounts(*sentence);
}
extractRules(sentence);
consolidateRules();
writeRulesToFile();
extractedRules.clear();
ExtractTask *task = new ExtractTask(i-1, sentence, options, extractCollector, extractCollectorInv);
#ifdef WITH_THREADS
if (thread_count == 1) {
task->Run();
delete task;
}
else {
pool.Submit(task);
}
#else
task->Run();
delete task;
#endif
}
if (options.onlyOutputSpanInfo) cout << "LOG: PHRASES_END:" << endl; //az: mark end of phrases
}
#ifdef WITH_THREADS
// wait for all threads to finish
pool.Stop(true);
#endif
tFile.Close();
sFile.Close();
aFile.Close();
@ -290,23 +362,30 @@ int main(int argc, char* argv[])
}
if (options.glueGrammarFlag)
writeGlueGrammar(fileNameGlueGrammar);
writeGlueGrammar(fileNameGlueGrammar, options, targetLabelCollection, targetTopLabelCollection);
if (options.unknownWordLabelFlag)
writeUnknownWordLabel(fileNameUnknownWordLabel);
}
void extractRules( SentenceAlignmentWithSyntax &sentence )
// Processes the sentence pair held by this task. The four steps must run in
// this order: each one works on what the previous one left in
// m_extractedRules.
void ExtractTask::Run() {
// fill m_extractedRules with the rules found for m_sentence
extractRules();
// compute the rule counts and consolidate them
consolidateRules();
// format the rules and pass the text to the output collectors under m_id
writeRulesToFile();
// the text has been handed off; empty the per-sentence rule buffer
m_extractedRules.clear();
}
void ExtractTask::extractRules()
{
int countT = sentence.target.size();
int countS = sentence.source.size();
int countT = m_sentence->target.size();
int countS = m_sentence->source.size();
// phrase repository for creating hiero phrases
RuleExist ruleExist(countT);
// check alignments for target phrase startT...endT
for(int lengthT=1;
lengthT <= options.maxSpan && lengthT <= countT;
lengthT <= m_options.maxSpan && lengthT <= countT;
lengthT++) {
for(int startT=0; startT < countT-(lengthT-1); startT++) {
@ -314,17 +393,17 @@ void extractRules( SentenceAlignmentWithSyntax &sentence )
int endT = startT + lengthT - 1;
// if there is target side syntax, there has to be a node
if (options.targetSyntax && !sentence.targetTree.HasNode(startT,endT))
if (m_options.targetSyntax && !m_sentence->targetTree.HasNode(startT,endT))
continue;
// find aligned source words
// first: find minimum and maximum source word
int minS = 9999;
int maxS = -1;
vector< int > usedS = sentence.alignedCountS;
vector< int > usedS = m_sentence->alignedCountS;
for(int ti=startT; ti<=endT; ti++) {
for(int i=0; i<sentence.alignedToT[ti].size(); i++) {
int si = sentence.alignedToT[ti][i];
for(unsigned int i=0; i<m_sentence->alignedToT[ti].size(); i++) {
int si = m_sentence->alignedToT[ti][i];
if (si<minS) {
minS = si;
}
@ -340,7 +419,7 @@ void extractRules( SentenceAlignmentWithSyntax &sentence )
continue;
// source phrase has to be within limits
if( maxS-minS >= options.maxSpan )
if( maxS-minS >= m_options.maxSpan )
continue;
// check if source words are aligned to out of bound target words
@ -358,23 +437,23 @@ void extractRules( SentenceAlignmentWithSyntax &sentence )
// start point of source phrase may retreat over unaligned
for(int startS=minS;
(startS>=0 &&
startS>maxS - options.maxSpan && // within length limit
(startS==minS || sentence.alignedCountS[startS]==0)); // unaligned
startS>maxS - m_options.maxSpan && // within length limit
(startS==minS || m_sentence->alignedCountS[startS]==0)); // unaligned
startS--) {
// end point of source phrase may advance over unaligned
for(int endS=maxS;
(endS<countS && endS<startS + options.maxSpan && // within length limit
(endS==maxS || sentence.alignedCountS[endS]==0)); // unaligned
(endS<countS && endS<startS + m_options.maxSpan && // within length limit
(endS==maxS || m_sentence->alignedCountS[endS]==0)); // unaligned
endS++) {
// if there is source side syntax, there has to be a node
if (options.sourceSyntax && !sentence.sourceTree.HasNode(startS,endS))
if (m_options.sourceSyntax && !m_sentence->sourceTree.HasNode(startS,endS))
continue;
// TODO: loop over all source and target syntax labels
// if within length limits, add as fully-lexical phrase pair
if (endT-startT < options.maxSymbolsTarget && endS-startS < options.maxSymbolsSource) {
addRule(sentence,startT,endT,startS,endS, ruleExist);
if (endT-startT < m_options.maxSymbolsTarget && endS-startS < m_options.maxSymbolsSource) {
addRule(startT,endT,startS,endS, ruleExist);
}
// take note that this is a valid phrase alignment
@ -383,10 +462,10 @@ void extractRules( SentenceAlignmentWithSyntax &sentence )
// extract hierarchical rules
// are rules not allowed to start non-terminals?
int initStartT = options.nonTermFirstWord ? startT : startT + 1;
int initStartT = m_options.nonTermFirstWord ? startT : startT + 1;
HoleCollection holeColl(startS, endS); // empty hole collection
addHieroRule(sentence, startT, endT, startS, endS,
addHieroRule(startT, endT, startS, endS,
ruleExist, holeColl, 0, initStartT,
endT-startT+1, endS-startS+1);
}
@ -395,8 +474,7 @@ void extractRules( SentenceAlignmentWithSyntax &sentence )
}
}
void preprocessSourceHieroPhrase( SentenceAlignmentWithSyntax &sentence
, int startT, int endT, int startS, int endS
void ExtractTask::preprocessSourceHieroPhrase( int startT, int endT, int startS, int endS
, WordIndex &indexS, HoleCollection &holeColl, const LabelIndex &labelIndex)
{
vector<Hole*>::iterator iterHoleList = holeColl.GetSortedSourceHoles().begin();
@ -416,8 +494,8 @@ void preprocessSourceHieroPhrase( SentenceAlignmentWithSyntax &sentence
Hole &hole = **iterHoleList;
int labelI = labelIndex[ 2+holeCount+holeTotal ];
string label = options.sourceSyntax ?
sentence.sourceTree.GetNodes(currPos,hole.GetEnd(0))[ labelI ]->GetLabel() : "X";
string label = m_options.sourceSyntax ?
m_sentence->sourceTree.GetNodes(currPos,hole.GetEnd(0))[ labelI ]->GetLabel() : "X";
hole.SetLabel(label, 0);
currPos = hole.GetEnd(0);
@ -434,8 +512,7 @@ void preprocessSourceHieroPhrase( SentenceAlignmentWithSyntax &sentence
assert(iterHoleList == holeColl.GetSortedSourceHoles().end());
}
string printTargetHieroPhrase(SentenceAlignmentWithSyntax &sentence
, int startT, int endT, int startS, int endS
string ExtractTask::printTargetHieroPhrase( int startT, int endT, int startS, int endS
, WordIndex &indexT, HoleCollection &holeColl, const LabelIndex &labelIndex)
{
HoleList::iterator iterHoleList = holeColl.GetHoles().begin();
@ -458,8 +535,8 @@ string printTargetHieroPhrase(SentenceAlignmentWithSyntax &sentence
assert(sourceLabel != "");
int labelI = labelIndex[ 2+holeCount ];
string targetLabel = options.targetSyntax ?
sentence.targetTree.GetNodes(currPos,hole.GetEnd(1))[ labelI ]->GetLabel() : "X";
string targetLabel = m_options.targetSyntax ?
m_sentence->targetTree.GetNodes(currPos,hole.GetEnd(1))[ labelI ]->GetLabel() : "X";
hole.SetLabel(targetLabel, 1);
out += "[" + sourceLabel + "][" + targetLabel + "] ";
@ -470,7 +547,7 @@ string printTargetHieroPhrase(SentenceAlignmentWithSyntax &sentence
holeCount++;
} else {
indexT[currPos] = outPos;
out += sentence.target[currPos] + " ";
out += m_sentence->target[currPos] + " ";
}
outPos++;
@ -480,8 +557,7 @@ string printTargetHieroPhrase(SentenceAlignmentWithSyntax &sentence
return out.erase(out.size()-1);
}
string printSourceHieroPhrase( SentenceAlignmentWithSyntax &sentence
, int startT, int endT, int startS, int endS
string ExtractTask::printSourceHieroPhrase( int startT, int endT, int startS, int endS
, HoleCollection &holeColl, const LabelIndex &labelIndex)
{
vector<Hole*>::iterator iterHoleList = holeColl.GetSortedSourceHoles().begin();
@ -511,7 +587,7 @@ string printSourceHieroPhrase( SentenceAlignmentWithSyntax &sentence
++iterHoleList;
++holeCount;
} else {
out += sentence.source[currPos] + " ";
out += m_sentence->source[currPos] + " ";
}
outPos++;
@ -521,20 +597,19 @@ string printSourceHieroPhrase( SentenceAlignmentWithSyntax &sentence
return out.erase(out.size()-1);
}
void printHieroAlignment(SentenceAlignmentWithSyntax &sentence
, int startT, int endT, int startS, int endS
void ExtractTask::printHieroAlignment( int startT, int endT, int startS, int endS
, const WordIndex &indexS, const WordIndex &indexT, HoleCollection &holeColl, ExtractedRule &rule)
{
// print alignment of words
for(int ti=startT; ti<=endT; ti++) {
WordIndex::const_iterator p = indexT.find(ti);
if (p != indexT.end()) { // does word still exist?
for(int i=0; i<sentence.alignedToT[ti].size(); i++) {
int si = sentence.alignedToT[ti][i];
for(unsigned int i=0; i<m_sentence->alignedToT[ti].size(); i++) {
int si = m_sentence->alignedToT[ti][i];
std::string sourceSymbolIndex = IntToString(indexS.find(si)->second);
std::string targetSymbolIndex = IntToString(p->second);
rule.alignment += sourceSymbolIndex + "-" + targetSymbolIndex + " ";
if (! options.onlyDirectFlag)
if (! m_options.onlyDirectFlag)
rule.alignmentInv += targetSymbolIndex + "-" + sourceSymbolIndex + " ";
}
}
@ -548,7 +623,7 @@ void printHieroAlignment(SentenceAlignmentWithSyntax &sentence
std::string sourceSymbolIndex = IntToString(hole.GetPos(0));
std::string targetSymbolIndex = IntToString(hole.GetPos(1));
rule.alignment += sourceSymbolIndex + "-" + targetSymbolIndex + " ";
if (!options.onlyDirectFlag)
if (!m_options.onlyDirectFlag)
rule.alignmentInv += targetSymbolIndex + "-" + sourceSymbolIndex + " ";
rule.SetSpanLength(hole.GetPos(0), hole.GetSize(0), hole.GetSize(1) ) ;
@ -556,12 +631,12 @@ void printHieroAlignment(SentenceAlignmentWithSyntax &sentence
}
rule.alignment.erase(rule.alignment.size()-1);
if (!options.onlyDirectFlag) {
if (!m_options.onlyDirectFlag) {
rule.alignmentInv.erase(rule.alignmentInv.size()-1);
}
}
void printHieroPhrase( SentenceAlignmentWithSyntax &sentence, int startT, int endT, int startS, int endS
void ExtractTask::printHieroPhrase( int startT, int endT, int startS, int endS
, HoleCollection &holeColl, LabelIndex &labelIndex)
{
WordIndex indexS, indexT; // to keep track of word positions in rule
@ -569,50 +644,48 @@ void printHieroPhrase( SentenceAlignmentWithSyntax &sentence, int startT, int en
ExtractedRule rule( startT, endT, startS, endS );
// phrase labels
string targetLabel = options.targetSyntax ?
sentence.targetTree.GetNodes(startT,endT)[ labelIndex[0] ]->GetLabel() : "X";
string sourceLabel = options.sourceSyntax ?
sentence.sourceTree.GetNodes(startS,endS)[ labelIndex[1] ]->GetLabel() : "X";
string targetLabel = m_options.targetSyntax ?
m_sentence->targetTree.GetNodes(startT,endT)[ labelIndex[0] ]->GetLabel() : "X";
string sourceLabel = m_options.sourceSyntax ?
m_sentence->sourceTree.GetNodes(startS,endS)[ labelIndex[1] ]->GetLabel() : "X";
//string sourceLabel = "X";
// create non-terms on the source side
preprocessSourceHieroPhrase(sentence, startT, endT, startS, endS, indexS, holeColl, labelIndex);
preprocessSourceHieroPhrase(startT, endT, startS, endS, indexS, holeColl, labelIndex);
// target
rule.target = printTargetHieroPhrase(sentence, startT, endT, startS, endS, indexT, holeColl, labelIndex)
rule.target = printTargetHieroPhrase(startT, endT, startS, endS, indexT, holeColl, labelIndex)
+ " [" + targetLabel + "]";
// source
// holeColl.SortSourceHoles();
rule.source = printSourceHieroPhrase(sentence, startT, endT, startS, endS, holeColl, labelIndex)
rule.source = printSourceHieroPhrase(startT, endT, startS, endS, holeColl, labelIndex)
+ " [" + sourceLabel + "]";
// alignment
printHieroAlignment(sentence, startT, endT, startS, endS, indexS, indexT, holeColl, rule);
printHieroAlignment(startT, endT, startS, endS, indexS, indexT, holeColl, rule);
addRuleToCollection( rule );
}
void printAllHieroPhrases( SentenceAlignmentWithSyntax &sentence
, int startT, int endT, int startS, int endS
, HoleCollection &holeColl)
void ExtractTask::printAllHieroPhrases( int startT, int endT, int startS, int endS, HoleCollection &holeColl)
{
LabelIndex labelIndex,labelCount;
// number of target head labels
int numLabels = options.targetSyntax ? sentence.targetTree.GetNodes(startT,endT).size() : 1;
int numLabels = m_options.targetSyntax ? m_sentence->targetTree.GetNodes(startT,endT).size() : 1;
labelCount.push_back(numLabels);
labelIndex.push_back(0);
// number of source head labels
numLabels = options.sourceSyntax ? sentence.sourceTree.GetNodes(startS,endS).size() : 1;
numLabels = m_options.sourceSyntax ? m_sentence->sourceTree.GetNodes(startS,endS).size() : 1;
labelCount.push_back(numLabels);
labelIndex.push_back(0);
// number of target hole labels
for( HoleList::const_iterator hole = holeColl.GetHoles().begin();
hole != holeColl.GetHoles().end(); hole++ ) {
int numLabels = options.targetSyntax ? sentence.targetTree.GetNodes(hole->GetStart(1),hole->GetEnd(1)).size() : 1 ;
int numLabels = m_options.targetSyntax ? m_sentence->targetTree.GetNodes(hole->GetStart(1),hole->GetEnd(1)).size() : 1 ;
labelCount.push_back(numLabels);
labelIndex.push_back(0);
}
@ -622,7 +695,7 @@ void printAllHieroPhrases( SentenceAlignmentWithSyntax &sentence
for( vector<Hole*>::iterator i = holeColl.GetSortedSourceHoles().begin();
i != holeColl.GetSortedSourceHoles().end(); i++ ) {
const Hole &hole = **i;
int numLabels = options.sourceSyntax ? sentence.sourceTree.GetNodes(hole.GetStart(0),hole.GetEnd(0)).size() : 1 ;
int numLabels = m_options.sourceSyntax ? m_sentence->sourceTree.GetNodes(hole.GetStart(0),hole.GetEnd(0)).size() : 1 ;
labelCount.push_back(numLabels);
labelIndex.push_back(0);
}
@ -630,8 +703,8 @@ void printAllHieroPhrases( SentenceAlignmentWithSyntax &sentence
// loop through the holes
bool done = false;
while(!done) {
printHieroPhrase( sentence, startT, endT, startS, endS, holeColl, labelIndex );
for(int i=0; i<labelIndex.size(); i++) {
printHieroPhrase( startT, endT, startS, endS, holeColl, labelIndex );
for(unsigned int i=0; i<labelIndex.size(); i++) {
labelIndex[i]++;
if(labelIndex[i] == labelCount[i]) {
labelIndex[i] = 0;
@ -646,27 +719,26 @@ void printAllHieroPhrases( SentenceAlignmentWithSyntax &sentence
// this function is called recursively
// it pokes a new hole into the phrase pair, and then calls itself for more holes
void addHieroRule( SentenceAlignmentWithSyntax &sentence
, int startT, int endT, int startS, int endS
void ExtractTask::addHieroRule( int startT, int endT, int startS, int endS
, RuleExist &ruleExist, const HoleCollection &holeColl
, int numHoles, int initStartT, int wordCountT, int wordCountS)
{
// done, if already the maximum number of non-terminals in phrase pair
if (numHoles >= options.maxNonTerm)
if (numHoles >= m_options.maxNonTerm)
return;
// find a hole...
for (int startHoleT = initStartT; startHoleT <= endT; ++startHoleT) {
for (int endHoleT = startHoleT+(options.minHoleTarget-1); endHoleT <= endT; ++endHoleT) {
for (int endHoleT = startHoleT+(m_options.minHoleTarget-1); endHoleT <= endT; ++endHoleT) {
// if last non-terminal, enforce word count limit
if (numHoles == options.maxNonTerm-1 && wordCountT - (endHoleT-startT+1) + (numHoles+1) > options.maxSymbolsTarget)
if (numHoles == m_options.maxNonTerm-1 && wordCountT - (endHoleT-startT+1) + (numHoles+1) > m_options.maxSymbolsTarget)
continue;
// determine the number of remaining target words
const int newWordCountT = wordCountT - (endHoleT-startHoleT+1);
// always enforce min word count limit
if (newWordCountT < options.minWords)
if (newWordCountT < m_options.minWords)
continue;
// except the whole span
@ -686,18 +758,18 @@ void addHieroRule( SentenceAlignmentWithSyntax &sentence
const int sourceHoleSize = sourceHole.GetEnd(0)-sourceHole.GetStart(0)+1;
// enforce minimum hole size
if (sourceHoleSize < options.minHoleSource)
if (sourceHoleSize < m_options.minHoleSource)
continue;
// determine the number of remaining source words
const int newWordCountS = wordCountS - sourceHoleSize;
// if last non-terminal, enforce word count limit
if (numHoles == options.maxNonTerm-1 && newWordCountS + (numHoles+1) > options.maxSymbolsSource)
if (numHoles == m_options.maxNonTerm-1 && newWordCountS + (numHoles+1) > m_options.maxSymbolsSource)
continue;
// enforce min word count limit
if (newWordCountS < options.minWords)
if (newWordCountS < m_options.minWords)
continue;
// hole must be subphrase of the source phrase
@ -710,16 +782,16 @@ void addHieroRule( SentenceAlignmentWithSyntax &sentence
continue;
// if consecutive non-terminals are not allowed, also check for source
if (!options.nonTermConsecSource && holeColl.ConsecSource(sourceHole) )
if (!m_options.nonTermConsecSource && holeColl.ConsecSource(sourceHole) )
continue;
// check that rule scope would not exceed limit if sourceHole
// were added
if (holeColl.Scope(sourceHole) > options.maxScope)
if (holeColl.Scope(sourceHole) > m_options.maxScope)
continue;
// require that at least one aligned word is left (unless there are no words at all)
if (options.requireAlignedWord && (newWordCountS > 0 || newWordCountT > 0)) {
if (m_options.requireAlignedWord && (newWordCountS > 0 || newWordCountT > 0)) {
HoleList::const_iterator iterHoleList = holeColl.GetHoles().begin();
bool foundAlignedWord = false;
// loop through all word positions
@ -735,7 +807,7 @@ void addHieroRule( SentenceAlignmentWithSyntax &sentence
}
// covered by word? check if it is aligned
else {
if (sentence.alignedToT[pos].size() > 0)
if (m_sentence->alignedToT[pos].size() > 0)
foundAlignedWord = true;
}
}
@ -751,19 +823,19 @@ void addHieroRule( SentenceAlignmentWithSyntax &sentence
bool allowablePhrase = true;
// maximum words count violation?
if (newWordCountS + (numHoles+1) > options.maxSymbolsSource)
if (newWordCountS + (numHoles+1) > m_options.maxSymbolsSource)
allowablePhrase = false;
if (newWordCountT + (numHoles+1) > options.maxSymbolsTarget)
if (newWordCountT + (numHoles+1) > m_options.maxSymbolsTarget)
allowablePhrase = false;
// passed all checks...
if (allowablePhrase)
printAllHieroPhrases(sentence, startT, endT, startS, endS, copyHoleColl);
printAllHieroPhrases(startT, endT, startS, endS, copyHoleColl);
// recursively search for next hole
int nextInitStartT = options.nonTermConsecTarget ? endHoleT + 1 : endHoleT + 2;
addHieroRule(sentence, startT, endT, startS, endS
int nextInitStartT = m_options.nonTermConsecTarget ? endHoleT + 1 : endHoleT + 2;
addHieroRule(startT, endT, startS, endS
, ruleExist, copyHoleColl, numHoles + 1, nextInitStartT
, newWordCountT, newWordCountS);
}
@ -771,12 +843,11 @@ void addHieroRule( SentenceAlignmentWithSyntax &sentence
}
}
void addRule( SentenceAlignmentWithSyntax &sentence, int startT, int endT, int startS, int endS
, RuleExist &ruleExist)
void ExtractTask::addRule( int startT, int endT, int startS, int endS, RuleExist &ruleExist)
{
// source
if (options.onlyOutputSpanInfo) {
if (m_options.onlyOutputSpanInfo) {
cout << startS << " " << endS << " " << startT << " " << endT << endl;
return;
}
@ -785,49 +856,49 @@ void addRule( SentenceAlignmentWithSyntax &sentence, int startT, int endT, int s
// phrase labels
string targetLabel,sourceLabel;
sourceLabel = options.sourceSyntax ?
sentence.sourceTree.GetNodes(startS,endS)[0]->GetLabel() : "X";
targetLabel = options.targetSyntax ?
sentence.targetTree.GetNodes(startT,endT)[0]->GetLabel() : "X";
sourceLabel = m_options.sourceSyntax ?
m_sentence->sourceTree.GetNodes(startS,endS)[0]->GetLabel() : "X";
targetLabel = m_options.targetSyntax ?
m_sentence->targetTree.GetNodes(startT,endT)[0]->GetLabel() : "X";
// source
rule.source = "";
for(int si=startS; si<=endS; si++)
rule.source += sentence.source[si] + " ";
rule.source += m_sentence->source[si] + " ";
rule.source += "[" + sourceLabel + "]";
// target
rule.target = "";
for(int ti=startT; ti<=endT; ti++)
rule.target += sentence.target[ti] + " ";
rule.target += m_sentence->target[ti] + " ";
rule.target += "[" + targetLabel + "]";
// alignment
for(int ti=startT; ti<=endT; ti++) {
for(int i=0; i<sentence.alignedToT[ti].size(); i++) {
int si = sentence.alignedToT[ti][i];
for(unsigned int i=0; i<m_sentence->alignedToT[ti].size(); i++) {
int si = m_sentence->alignedToT[ti][i];
std::string sourceSymbolIndex = IntToString(si-startS);
std::string targetSymbolIndex = IntToString(ti-startT);
rule.alignment += sourceSymbolIndex + "-" + targetSymbolIndex + " ";
if (!options.onlyDirectFlag)
if (!m_options.onlyDirectFlag)
rule.alignmentInv += targetSymbolIndex + "-" + sourceSymbolIndex + " ";
}
}
rule.alignment.erase(rule.alignment.size()-1);
if (!options.onlyDirectFlag)
if (!m_options.onlyDirectFlag)
rule.alignmentInv.erase(rule.alignmentInv.size()-1);
addRuleToCollection( rule );
}
void addRuleToCollection( ExtractedRule &newRule )
void ExtractTask::addRuleToCollection( ExtractedRule &newRule )
{
// no double-counting of identical rules from overlapping spans
if (!options.duplicateRules) {
if (!m_options.duplicateRules) {
vector<ExtractedRule>::const_iterator rule;
for(rule = extractedRules.begin(); rule != extractedRules.end(); rule++ ) {
for(rule = m_extractedRules.begin(); rule != m_extractedRules.end(); rule++ ) {
if (rule->source.compare( newRule.source ) == 0 &&
rule->target.compare( newRule.target ) == 0 &&
!(rule->endT < newRule.startT || rule->startT > newRule.endT)) { // overlapping
@ -835,31 +906,31 @@ void addRuleToCollection( ExtractedRule &newRule )
}
}
}
extractedRules.push_back( newRule );
m_extractedRules.push_back( newRule );
}
void consolidateRules()
void ExtractTask::consolidateRules()
{
typedef vector<ExtractedRule>::iterator R;
map<int, map<int, map<int, map<int,int> > > > spanCount;
// compute number of rules per span
if (options.fractionalCounting) {
for(R rule = extractedRules.begin(); rule != extractedRules.end(); rule++ ) {
if (m_options.fractionalCounting) {
for(R rule = m_extractedRules.begin(); rule != m_extractedRules.end(); rule++ ) {
spanCount[ rule->startT ][ rule->endT ][ rule->startS ][ rule->endS ]++;
}
}
// compute fractional counts
for(R rule = extractedRules.begin(); rule != extractedRules.end(); rule++ ) {
rule->count = 1.0/(float) (options.fractionalCounting ? spanCount[ rule->startT ][ rule->endT ][ rule->startS ][ rule->endS ] : 1.0 );
for(R rule = m_extractedRules.begin(); rule != m_extractedRules.end(); rule++ ) {
rule->count = 1.0/(float) (m_options.fractionalCounting ? spanCount[ rule->startT ][ rule->endT ][ rule->startS ][ rule->endS ] : 1.0 );
}
// consolidate counts
for(R rule = extractedRules.begin(); rule != extractedRules.end(); rule++ ) {
for(R rule = m_extractedRules.begin(); rule != m_extractedRules.end(); rule++ ) {
if (rule->count == 0)
continue;
for(R r2 = rule+1; r2 != extractedRules.end(); r2++ ) {
for(R r2 = rule+1; r2 != m_extractedRules.end(); r2++ ) {
if (rule->source.compare( r2->source ) == 0 &&
rule->target.compare( r2->target ) == 0 &&
rule->alignment.compare( r2->alignment ) == 0) {
@ -870,33 +941,37 @@ void consolidateRules()
}
}
void writeRulesToFile()
void ExtractTask::writeRulesToFile()
{
vector<ExtractedRule>::const_iterator rule;
for(rule = extractedRules.begin(); rule != extractedRules.end(); rule++ ) {
ostringstream out;
ostringstream outInv;
for(rule = m_extractedRules.begin(); rule != m_extractedRules.end(); rule++ ) {
if (rule->count == 0)
continue;
extractFile << rule->source << " ||| "
<< rule->target << " ||| "
<< rule->alignment << " ||| "
<< rule->count;
if (options.outputNTLengths) {
extractFile << " ||| ";
rule->OutputNTLengths(extractFile);
out << rule->source << " ||| "
<< rule->target << " ||| "
<< rule->alignment << " ||| "
<< rule->count;
if (m_options.outputNTLengths) {
out << " ||| ";
rule->OutputNTLengths(out);
}
extractFile << "\n";
out << "\n";
if (!options.onlyDirectFlag) {
extractFileInv << rule->target << " ||| "
<< rule->source << " ||| "
<< rule->alignmentInv << " ||| "
<< rule->count << "\n";
if (!m_options.onlyDirectFlag) {
outInv << rule->target << " ||| "
<< rule->source << " ||| "
<< rule->alignmentInv << " ||| "
<< rule->count << "\n";
}
}
m_extractCollector->Write( m_id, out.str() );
m_extractCollectorInv->Write( m_id, outInv.str() );;
}
void writeGlueGrammar( const string & fileName )
void writeGlueGrammar( const string & fileName, RuleExtractionOptions &options, set< string > &targetLabelCollection, map< string, int > &targetTopLabelCollection )
{
ofstream grammarFile;
grammarFile.open(fileName.c_str());
@ -907,7 +982,7 @@ void writeGlueGrammar( const string & fileName )
} else {
// chose a top label that is not already a label
string topLabel = "QQQQQQ";
for( int i=1; i<=topLabel.length(); i++) {
for( unsigned int i=1; i<=topLabel.length(); i++) {
if(targetLabelCollection.find( topLabel.substr(0,i) ) == targetLabelCollection.end() ) {
topLabel = topLabel.substr(0,i);
break;