SoftSourceSyntacticConstraintsFeature: Sparse label pair scores

This commit is contained in:
Matthias Huck 2015-02-26 20:27:02 +00:00
parent a3d2adca50
commit 0afc261251
2 changed files with 63 additions and 46 deletions

View File

@ -24,7 +24,9 @@ SoftSourceSyntacticConstraintsFeature::SoftSourceSyntacticConstraintsFeature(con
, m_useCoreSourceLabels(false)
, m_useLogprobs(true)
, m_useSparse(false)
, m_useSparseLabelPairs(false)
, m_noMismatches(false)
, m_floor(1e-7)
{
VERBOSE(1, "Initializing feature " << GetScoreProducerDescription() << " ...");
ReadParameters();
@ -42,6 +44,12 @@ SoftSourceSyntacticConstraintsFeature::SoftSourceSyntacticConstraintsFeature(con
} else {
VERBOSE(1, " inactive.");
}
VERBOSE(1, " Sparse label pair scores");
if ( m_useSparseLabelPairs ) {
VERBOSE(1, " active.");
} else {
VERBOSE(1, " inactive.");
}
VERBOSE(1, " Core labels");
if ( m_useCoreSourceLabels ) {
VERBOSE(1, " active.");
@ -73,6 +81,8 @@ void SoftSourceSyntacticConstraintsFeature::SetParameter(const std::string& key,
m_useLogprobs = Scan<bool>(value);
} else if (key == "sparse") {
m_useSparse = Scan<bool>(value);
} else if (key == "sparseLabelPairs") {
m_useSparseLabelPairs = Scan<bool>(value);
} else {
StatelessFeatureFunction::SetParameter(key, value);
}
@ -240,7 +250,8 @@ void SoftSourceSyntacticConstraintsFeature::LoadTargetSourceLeftHandSideJointCou
sourceVector->at(foundSourceLabelIndex->second) = std::pair<float,float>(count,count);
std::pair< boost::unordered_map<const Factor*, std::vector< std::pair<float,float> >* >::iterator, bool > insertedJointCount =
m_labelPairProbabilities.insert( std::pair<const Factor*, std::vector< std::pair<float,float> >* >(targetLabelFactor,sourceVector) );
assert(insertedJointCount.second);
UTIL_THROW_IF2(!insertedJointCount.second, GetScoreProducerDescription()
<< ": Reading target/source label joint counts from file " << m_targetSourceLHSJointCountFile << " failed.");
}
}
@ -305,8 +316,6 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
bool treeInputMismatchLHSBinary = true;
size_t treeInputMismatchRHSCount = 0;
bool hasCompleteTreeInputMatch = false;
float t2sLabelsProb = 1;
float s2tLabelsProb = 1;
float ruleLabelledProbability = 0.0;
float treeInputMatchProbRHS = 0.0;
float treeInputMatchProbLHS = 0.0;
@ -328,14 +337,11 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
boost::unordered_set<size_t> treeInputLabelsLHS;
// get index map for underlying hypotheses
// const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
// targetPhrase.GetAlignNonTerm().GetNonTermIndexMap();
const WordsRange& wordsRange = inputPath.GetWordsRange();
size_t startPos = wordsRange.GetStartPos();
size_t endPos = wordsRange.GetEndPos();
const Phrase *sourcePhrase = targetPhrase.GetRuleSource();
// std::vector<const Factor*> targetLabelsRHS;
if (nNTs > 1) { // rule has right-hand side non-terminals, i.e. it's a hierarchical rule
size_t nonTerminalNumber = 0;
size_t sourceSentPos = startPos;
@ -346,10 +352,6 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
size_t symbolStartPos = sourceSentPos;
size_t symbolEndPos = sourceSentPos;
if ( word.IsNonTerminal() ) {
// non-terminal: consult subderivation
// size_t nonTermIndex = nonTermIndexMap[phrasePos];
// targetLabelsRHS.push_back( word[0] );
// retrieve information that is required for input tree label matching (RHS)
const ChartCellLabel &cell = *stackVec->at(nonTerminalNumber);
const WordsRange& prevWordsRange = cell.GetCoverage();
@ -477,25 +479,6 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
}
}
// if ( hasCompleteTreeInputMatch ) {
//
// std::pair<float,float> probPair = GetLabelPairProbabilities( targetLHS, sourceLabelsLHSIt->first);
// t2sLabelsProb = probPair.first;
// s2tLabelsProb = probPair.second;
// nonTerminalNumber=0;
// for (std::list<size_t>::const_iterator sourceLabelsRHSIt = sourceLabelsRHS.begin();
// sourceLabelsRHSIt != sourceLabelsRHS.end(); ++sourceLabelsRHSIt, ++nonTerminalNumber) {
// probPair = GetLabelPairProbabilities( targetLabelsRHS[nonTerminalNumber], *sourceLabelsRHSIt );
// t2sLabelsProb += probPair.first;
// s2tLabelsProb += probPair.second;
// }
// t2sLabelsProb /= nNTs;
// s2tLabelsProb /= nNTs;
// assert(t2sLabelsProb != 0);
// assert(s2tLabelsProb != 0);
// }
}
// normalization
@ -546,26 +529,62 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
// LHS
if ( m_useSparse ) {
float score_LHS_0 = (float)1/treeInputLabelsLHS.size();
for (boost::unordered_set<size_t>::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin();
treeInputLabelsLHSIt != treeInputLabelsLHS.end(); ++treeInputLabelsLHSIt) {
float score_LHS_0 = (float)1/treeInputLabelsLHS.size();
for (boost::unordered_set<size_t>::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin();
treeInputLabelsLHSIt != treeInputLabelsLHS.end(); ++treeInputLabelsLHSIt) {
if ( !m_useCoreSourceLabels || m_coreSourceLabels.find(*treeInputLabelsLHSIt) != m_coreSourceLabels.end() ) {
if ( !m_useCoreSourceLabels || m_coreSourceLabels.find(*treeInputLabelsLHSIt) != m_coreSourceLabels.end() ) {
if (sparseScoredTreeInputLabelsLHS.find(*treeInputLabelsLHSIt) == sparseScoredTreeInputLabelsLHS.end()) {
// score sparse features: RHS mismatch
scoreBreakdown.PlusEquals(this,
m_sourceLabelsByIndex_LHS_0[*treeInputLabelsLHSIt],
score_LHS_0);
}
if (sparseScoredTreeInputLabelsLHS.find(*treeInputLabelsLHSIt) == sparseScoredTreeInputLabelsLHS.end()) {
// score sparse features: RHS mismatch
scoreBreakdown.PlusEquals(this,
m_sourceLabelsByIndex_LHS_0[*treeInputLabelsLHSIt],
score_LHS_0);
}
}
}
}
if ( m_useSparseLabelPairs ) {
// left-hand side label pairs (target NT, source NT)
float t2sLabelsScore = 0.0;
float s2tLabelsScore = 0.0;
for (boost::unordered_set<size_t>::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin();
treeInputLabelsLHSIt != treeInputLabelsLHS.end(); ++treeInputLabelsLHSIt) {
scoreBreakdown.PlusEquals(this,
"LHSPAIR_" + targetLHS->GetString().as_string() + "_" + m_sourceLabelsByIndex[*treeInputLabelsLHSIt],
(float)1/treeInputLabelsLHS.size());
if (!m_targetSourceLHSJointCountFile.empty()) {
std::pair<float,float> probPair = GetLabelPairProbabilities( targetLHS, *treeInputLabelsLHSIt);
t2sLabelsScore += probPair.first;
s2tLabelsScore += probPair.second;
}
}
if ( treeInputLabelsLHS.size() == 0 ) {
scoreBreakdown.PlusEquals(this,
"LHSPAIR_" + targetLHS->GetString().as_string() + "_" + outputDefaultNonTerminal[0]->GetString().as_string(),
1);
if (!m_targetSourceLHSJointCountFile.empty()) {
t2sLabelsScore = TransformScore(m_floor);
s2tLabelsScore = TransformScore(m_floor);
}
} else {
if (!m_targetSourceLHSJointCountFile.empty()) {
float norm = TransformScore(treeInputLabelsLHS.size());
t2sLabelsScore = TransformScore(t2sLabelsScore) - norm;
s2tLabelsScore = TransformScore(s2tLabelsScore) - norm;
}
}
if (!m_targetSourceLHSJointCountFile.empty()) {
scoreBreakdown.PlusEquals(this, "LHST2S", t2sLabelsScore);
scoreBreakdown.PlusEquals(this, "LHSS2T", s2tLabelsScore);
}
}
} else {
// abort with error message if the phrase does not translate an unknown word
@ -589,10 +608,6 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
newScores[1] = treeInputMismatchLHSBinary;
newScores[2] = treeInputMismatchRHSCount;
// newScores[3] = hasCompleteTreeInputMatch ? TransformScore(ruleLabelledProbability) : 0;
// newScores[4] = hasCompleteTreeInputMatch ? TransformScore(t2sLabelsProb) : 0;
// newScores[5] = hasCompleteTreeInputMatch ? TransformScore(s2tLabelsProb) : 0;
if ( m_useLogprobs ) {
if ( ruleLabelledProbability != 0 ) {
ruleLabelledProbability = TransformScore(ruleLabelledProbability);
@ -617,7 +632,7 @@ std::pair<float,float> SoftSourceSyntacticConstraintsFeature::GetLabelPairProbab
boost::unordered_map<const Factor*, std::vector< std::pair<float,float> >* >::const_iterator found =
m_labelPairProbabilities.find(target);
if ( found == m_labelPairProbabilities.end() ) {
return std::pair<float,float>(0,0);
return std::pair<float,float>(m_floor,m_floor); // floor values
}
return found->second->at(source);
}

View File

@ -71,7 +71,9 @@ protected:
bool m_useCoreSourceLabels;
bool m_useLogprobs;
bool m_useSparse;
bool m_useSparseLabelPairs;
bool m_noMismatches;
float m_floor;
boost::unordered_map<std::string,size_t> m_sourceLabels;
std::vector<std::string> m_sourceLabelsByIndex;