add span length to training

This commit is contained in:
Hieu Hoang 2014-05-31 21:39:47 +01:00
parent 3ecd625920
commit ea1fb296fe
8 changed files with 32 additions and 9 deletions

View File

@ -38,8 +38,8 @@ int main(int argc, char** argv)
("HieroSourceLHS", "Always use Hiero source LHS? Default = 0")
("MaxSpanFreeNonTermSource", po::value<int>()->default_value(params.maxSpanFreeNonTermSource), "Max number of words covered by beginning/end NT. Default = 0 (no limit)")
("NoNieceTerminal", "Don't extract rule if 1 of the non-term covers the same word as 1 of the terminals")
("MaxScope", po::value<int>()->default_value(params.maxScope), "maximum scope (see Hopkins and Langmead (2010)). Default is HIGH");
("MaxScope", po::value<int>()->default_value(params.maxScope), "maximum scope (see Hopkins and Langmead (2010)). Default is HIGH")
("SpanLength", po::value<bool>()->default_value(params.spanLength), "Output span length of RHS each non-term");
po::variables_map vm;
try
@ -80,6 +80,7 @@ int main(int argc, char** argv)
if (vm.count("HieroSourceLHS")) params.hieroSourceLHS = true;
if (vm.count("MaxSpanFreeNonTermSource")) params.maxSpanFreeNonTermSource = vm["MaxSpanFreeNonTermSource"].as<int>();
if (vm.count("NoNieceTerminal")) params.nieceTerminal = false;
// SpanLength is registered with default_value(), so the variables_map always
// holds an entry for it and count() is non-zero even when the option is not
// given; read the parsed bool instead of unconditionally enabling the feature.
if (vm.count("SpanLength")) params.spanLength = vm["SpanLength"].as<bool>();
if (vm.count("MaxScope")) params.maxScope = vm["MaxScope"].as<int>();
// input files;
@ -142,8 +143,8 @@ int main(int argc, char** argv)
rules.Consolidate(params);
//cerr << rules.Debug();
rules.Output(extractFile, true);
rules.Output(extractInvFile, false);
rules.Output(extractFile, true, params);
rules.Output(extractInvFile, false, params);
delete alignedSentence;

View File

@ -29,6 +29,7 @@ Parameter::Parameter()
,hieroSourceLHS(false)
,maxSpanFreeNonTermSource(0)
,nieceTerminal(true)
,spanLength(false)
,maxScope(UNDEFINED)
{}

View File

@ -41,6 +41,7 @@ public:
bool hieroSourceLHS;
int maxSpanFreeNonTermSource;
bool nieceTerminal;
bool spanLength;
int maxScope;
};

View File

@ -121,7 +121,7 @@ std::string Rule::Debug() const
return out.str();
}
void Rule::Output(std::ostream &out, bool forward) const
void Rule::Output(std::ostream &out, bool forward, const Parameter &params) const
{
if (forward) {
// source
@ -167,6 +167,18 @@ void Rule::Output(std::ostream &out, bool forward) const
out << m_count;
out << " ||| ";
// span length
// When params.spanLength is set, append a "{{SpanLength ... }} " field holding
// one "index,width(Input),width(Output) " entry per non-terminal in m_nonterms.
if (params.spanLength) {
out << "{{SpanLength ";
for (size_t i = 0; i < m_nonterms.size(); ++i) {
const NonTerm &nonTerm = *m_nonterms[i];
const ConsistentPhrase &cp = nonTerm.GetConsistentPhrase();
// i is the position within m_nonterms; every entry is followed by one space
out << i << "," << cp.GetWidth(Moses::Input) << "," << cp.GetWidth(Moses::Output) << " ";
}
out << "}} ";
}
}
void Rule::Prevalidate(const Parameter &params)

View File

@ -52,7 +52,7 @@ public:
{ return m_alignments; }
std::string Debug() const;
void Output(std::ostream &out, bool forward) const;
void Output(std::ostream &out, bool forward, const Parameter &params) const;
void Prevalidate(const Parameter &params);
void CreateTarget(const Parameter &params);

View File

@ -151,12 +151,12 @@ std::string Rules::Debug() const
return out.str();
}
void Rules::Output(std::ostream &out, bool forward) const
void Rules::Output(std::ostream &out, bool forward, const Parameter &params) const
{
// Write every rule held in m_mergeRules to `out`, one rule per line, in the
// set's iteration order.
std::set<Rule*, CompareRules>::const_iterator iter;
for (iter = m_mergeRules.begin(); iter != m_mergeRules.end(); ++iter) {
const Rule &rule = **iter;
rule.Output(out, forward);
// The arguments are passed through unchanged to Rule::Output.
rule.Output(out, forward, params);
// Terminate this rule's line.
out << endl;
}
}

View File

@ -48,7 +48,7 @@ public:
void Consolidate(const Parameter &params);
std::string Debug() const;
void Output(std::ostream &out, bool forward) const;
void Output(std::ostream &out, bool forward, const Parameter &params) const;
protected:
const AlignedSentence &m_alignedSentence;

View File

@ -61,6 +61,7 @@ bool lexFlag = true;
bool unalignedFlag = false;
bool unalignedFWFlag = false;
bool crossedNonTerm = false;
bool spanLength = false;
int countOfCounts[COC_MAX+1];
int totalDistinct = 0;
float minCountHierarchical = 0;
@ -174,6 +175,9 @@ int main(int argc, char* argv[])
} else if (strcmp(argv[i],"--CrossedNonTerm") == 0) {
crossedNonTerm = true;
std::cerr << "crossed non-term reordering feature" << std::endl;
} else if (strcmp(argv[i],"--SpanLength") == 0) {
spanLength = true;
std::cerr << "span length feature" << std::endl;
} else {
featureArgs.push_back(argv[i]);
++i;
@ -659,6 +663,10 @@ void outputPhrasePair(const ExtractionPhrasePair &phrasePair,
}
}
// Append a SpanLength property only when the spanLength flag is set and
// inverseFlag is false.
// NOTE(review): "asasa" looks like a placeholder literal rather than real
// span-length data — TODO confirm and replace with the actual values.
if (spanLength && !inverseFlag) {
phraseTableFile << " {{SpanLength " << "asasa" << "}}";
}
phraseTableFile << std::endl;
}