SoftSourceSyntacticConstraintsFeature: better config parameter names

This commit is contained in:
Matthias Huck 2015-01-27 18:15:51 +00:00
parent 0a0ea437bb
commit 9f562e0fd4
2 changed files with 57 additions and 35 deletions

View File

@ -19,12 +19,21 @@ namespace Moses
{
// Constructs the feature from its configuration line.
// Flag defaults set by the initializer list, before any parameter is read:
// core source labels off, log probabilities on, sparse scores off,
// no-mismatches off.
// NOTE(review): the 6 handed to the base class is presumably the number of
// dense score components (newScores[0..5] are filled in
// EvaluateWithSourceContext) -- confirm against StatelessFeatureFunction.
SoftSourceSyntacticConstraintsFeature::SoftSourceSyntacticConstraintsFeature(const std::string &line)
: StatelessFeatureFunction(6, line), m_featureVariant(0)
: StatelessFeatureFunction(6, line)
, m_useCoreSourceLabels(false)
, m_useLogprobs(true)
, m_useSparse(false)
, m_noMismatches(false)
{
VERBOSE(1, "Initializing feature " << GetScoreProducerDescription() << " ...");
// Read the feature's parameters; values supplied there may replace the defaults above.
ReadParameters();
VERBOSE(1, " Done.");
VERBOSE(1, " Feature variant: " << m_featureVariant << "." << std::endl);
// Report, one option at a time, which boolean settings ended up active.
VERBOSE(1, " Config:");
VERBOSE(1, " Log probabilities"); if ( m_useLogprobs ) { VERBOSE(1, " active."); } else { VERBOSE(1, " inactive."); }
VERBOSE(1, " Sparse scores"); if ( m_useSparse ) { VERBOSE(1, " active."); } else { VERBOSE(1, " inactive."); }
VERBOSE(1, " Core labels"); if ( m_useCoreSourceLabels ) { VERBOSE(1, " active."); } else { VERBOSE(1, " inactive."); }
VERBOSE(1, " No mismatches"); if ( m_noMismatches ) { VERBOSE(1, " active."); } else { VERBOSE(1, " inactive."); }
// Terminate the single-line config summary.
VERBOSE(1, std::endl);
}
void SoftSourceSyntacticConstraintsFeature::SetParameter(const std::string& key, const std::string& value)
@ -33,10 +42,15 @@ void SoftSourceSyntacticConstraintsFeature::SetParameter(const std::string& key,
m_sourceLabelSetFile = value;
} else if (key == "coreSourceLabelSetFile") {
m_coreSourceLabelSetFile = value;
m_useCoreSourceLabels = true;
} else if (key == "targetSourceLeftHandSideJointCountFile") {
m_targetSourceLHSJointCountFile = value;
} else if (key == "featureVariant") {
m_featureVariant = Scan<size_t>(value); // 0: only dense features, 1: no mismatches (also set weights 1 0 0 and tuneable=false), 2: with sparse features, 3: with sparse features for core labels only
} else if (key == "noMismatches") {
m_noMismatches = Scan<bool>(value); // for a hard constraint, allow no mismatches (also set: weights 1 0 0 0 0 0, tuneable=false)
} else if (key == "logProbabilities") {
m_useLogprobs = Scan<bool>(value);
} else if (key == "sparse") {
m_useSparse = Scan<bool>(value);
} else {
StatelessFeatureFunction::SetParameter(key, value);
}
@ -47,7 +61,7 @@ void SoftSourceSyntacticConstraintsFeature::Load()
{
// don't change the loading order!
LoadSourceLabelSet();
if (m_featureVariant == 3) {
if (!m_coreSourceLabelSetFile.empty()) {
LoadCoreSourceLabelSet();
}
if (!m_targetSourceLHSJointCountFile.empty()) {
@ -272,7 +286,7 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
float t2sLabelsProb = 1;
float s2tLabelsProb = 1;
float ruleLabelledProbability = 0.0;
float treeInputMatchLogprobRHS = 0.0;
float treeInputMatchProbRHS = 0.0;
float treeInputMatchProbLHS = 0.0;
// read SourceLabels property
@ -389,8 +403,8 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
treeInputMatchRHSCountByNonTerminal[nonTerminalNumber] = true;
treeInputMatchProbRHSByNonTerminal[nonTerminalNumber] += sourceLabelsRHSCount; // to be normalized later on
if ( m_featureVariant == 2 ||
(m_featureVariant == 3 && m_coreSourceLabels.find(*sourceLabelsRHSIt) != m_coreSourceLabels.end()) ) {
if ( m_useSparse &&
(!m_useCoreSourceLabels || m_coreSourceLabels.find(*sourceLabelsRHSIt) != m_coreSourceLabels.end()) ) {
// score sparse features: RHS match
if (sparseScoredTreeInputLabelsRHS[nonTerminalNumber].find(*sourceLabelsRHSIt) == sparseScoredTreeInputLabelsRHS[nonTerminalNumber].end()) {
// (only if no match has been scored for this tree input label and rule non-terminal with a previous sourceLabelItem)
@ -421,8 +435,8 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
treeInputMismatchLHSBinary = false;
treeInputMatchProbLHS += sourceLabelsLHSIt->second; // to be normalized later on
if ( m_featureVariant == 2 ||
(m_featureVariant == 3 && m_coreSourceLabels.find(sourceLabelsLHSIt->first) != m_coreSourceLabels.end()) ) {
if ( m_useSparse &&
(!m_useCoreSourceLabels || m_coreSourceLabels.find(sourceLabelsLHSIt->first) != m_coreSourceLabels.end()) ) {
// score sparse features: LHS match
if (sparseScoredTreeInputLabelsLHS.find(sourceLabelsLHSIt->first) == sparseScoredTreeInputLabelsLHS.end()) {
// (only if no match has been scored for this tree input label and rule non-terminal with a previous sourceLabelItem)
@ -467,7 +481,7 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
treeInputMatchProbRHSByNonTerminalIt != treeInputMatchProbRHSByNonTerminal.end(); ++treeInputMatchProbRHSByNonTerminalIt) {
*treeInputMatchProbRHSByNonTerminalIt /= totalCount;
if ( *treeInputMatchProbRHSByNonTerminalIt != 0 ) {
treeInputMatchLogprobRHS += TransformScore(*treeInputMatchProbRHSByNonTerminalIt);
treeInputMatchProbRHS += ( m_useLogprobs ? TransformScore(*treeInputMatchProbRHSByNonTerminalIt) : *treeInputMatchProbRHSByNonTerminalIt );
}
}
treeInputMatchProbLHS /= totalCount;
@ -485,7 +499,7 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
}
// score sparse features: mismatches
if ( m_featureVariant == 2 || m_featureVariant == 3 ) {
if ( m_useSparse ) {
// RHS
@ -496,8 +510,7 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
for (boost::unordered_set<size_t>::const_iterator treeInputLabelsRHSIt = treeInputLabelsRHS[nonTerminalNumber].begin();
treeInputLabelsRHSIt != treeInputLabelsRHS[nonTerminalNumber].end(); ++treeInputLabelsRHSIt) {
if ( m_featureVariant == 2 ||
(m_featureVariant == 3 && m_coreSourceLabels.find(*treeInputLabelsRHSIt) != m_coreSourceLabels.end()) ) {
if ( !m_useCoreSourceLabels || m_coreSourceLabels.find(*treeInputLabelsRHSIt) != m_coreSourceLabels.end() ) {
if (sparseScoredTreeInputLabelsRHS[nonTerminalNumber].find(*treeInputLabelsRHSIt) == sparseScoredTreeInputLabelsRHS[nonTerminalNumber].end()) {
// score sparse features: RHS mismatch
@ -511,18 +524,20 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
// LHS
float score_LHS_0 = (float)1/treeInputLabelsLHS.size();
for (boost::unordered_set<size_t>::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin();
treeInputLabelsLHSIt != treeInputLabelsLHS.end(); ++treeInputLabelsLHSIt) {
if ( m_useSparse ) {
if ( m_featureVariant == 2 ||
(m_featureVariant == 3 && m_coreSourceLabels.find(*treeInputLabelsLHSIt) != m_coreSourceLabels.end()) ) {
float score_LHS_0 = (float)1/treeInputLabelsLHS.size();
for (boost::unordered_set<size_t>::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin();
treeInputLabelsLHSIt != treeInputLabelsLHS.end(); ++treeInputLabelsLHSIt) {
if (sparseScoredTreeInputLabelsLHS.find(*treeInputLabelsLHSIt) == sparseScoredTreeInputLabelsLHS.end()) {
// score sparse features: RHS mismatch
scoreBreakdown.PlusEquals(this,
m_sourceLabelsByIndex_LHS_0[*treeInputLabelsLHSIt],
score_LHS_0);
if ( !m_useCoreSourceLabels || m_coreSourceLabels.find(*treeInputLabelsLHSIt) != m_coreSourceLabels.end() ) {
if (sparseScoredTreeInputLabelsLHS.find(*treeInputLabelsLHSIt) == sparseScoredTreeInputLabelsLHS.end()) {
// score sparse features: LHS mismatch
scoreBreakdown.PlusEquals(this,
m_sourceLabelsByIndex_LHS_0[*treeInputLabelsLHSIt],
score_LHS_0);
}
}
}
}
@ -545,14 +560,9 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
// add scores
// input tree matching
switch (m_featureVariant) {
case 1:
newScores[0] = !hasCompleteTreeInputMatch;
if ( m_noMismatches ) {
newScores[0] = ( (hasCompleteTreeInputMatch || isGlueGrammarRule || isUnkRule) ? 0 : -std::numeric_limits<float>::infinity() );
break;
default:
newScores[0] = !hasCompleteTreeInputMatch;
}
newScores[1] = treeInputMismatchLHSBinary;
newScores[2] = treeInputMismatchRHSCount;
@ -561,9 +571,18 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
// newScores[4] = hasCompleteTreeInputMatch ? TransformScore(t2sLabelsProb) : 0;
// newScores[5] = hasCompleteTreeInputMatch ? TransformScore(s2tLabelsProb) : 0;
newScores[3] = (ruleLabelledProbability != 0) ? TransformScore(ruleLabelledProbability) : 0;
newScores[4] = (treeInputMatchProbLHS != 0) ? TransformScore(treeInputMatchProbLHS) : 0;
newScores[5] = treeInputMatchLogprobRHS;
if ( m_useLogprobs ) {
if ( ruleLabelledProbability != 0 ) {
ruleLabelledProbability = TransformScore(ruleLabelledProbability);
}
if ( treeInputMatchProbLHS != 0 ) {
treeInputMatchProbLHS = TransformScore(treeInputMatchProbLHS);
}
}
newScores[3] = ruleLabelledProbability;
newScores[4] = treeInputMatchProbLHS;
newScores[5] = treeInputMatchProbRHS;
scoreBreakdown.PlusEquals(this, newScores);
}

View File

@ -62,7 +62,10 @@ private:
std::string m_coreSourceLabelSetFile;
std::string m_targetSourceLHSJointCountFile;
std::string m_unknownLeftHandSideFile;
size_t m_featureVariant;
bool m_useCoreSourceLabels;
bool m_useLogprobs;
bool m_useSparse;
bool m_noMismatches;
boost::unordered_map<std::string,size_t> m_sourceLabels;
std::vector<std::string> m_sourceLabelsByIndex;