mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-25 12:52:29 +03:00
SoftSourceSyntacticConstraintsFeature: Sparse label pair scores
This commit is contained in:
parent
a3d2adca50
commit
0afc261251
@ -24,7 +24,9 @@ SoftSourceSyntacticConstraintsFeature::SoftSourceSyntacticConstraintsFeature(con
|
||||
, m_useCoreSourceLabels(false)
|
||||
, m_useLogprobs(true)
|
||||
, m_useSparse(false)
|
||||
, m_useSparseLabelPairs(false)
|
||||
, m_noMismatches(false)
|
||||
, m_floor(1e-7)
|
||||
{
|
||||
VERBOSE(1, "Initializing feature " << GetScoreProducerDescription() << " ...");
|
||||
ReadParameters();
|
||||
@ -42,6 +44,12 @@ SoftSourceSyntacticConstraintsFeature::SoftSourceSyntacticConstraintsFeature(con
|
||||
} else {
|
||||
VERBOSE(1, " inactive.");
|
||||
}
|
||||
VERBOSE(1, " Sparse label pair scores");
|
||||
if ( m_useSparseLabelPairs ) {
|
||||
VERBOSE(1, " active.");
|
||||
} else {
|
||||
VERBOSE(1, " inactive.");
|
||||
}
|
||||
VERBOSE(1, " Core labels");
|
||||
if ( m_useCoreSourceLabels ) {
|
||||
VERBOSE(1, " active.");
|
||||
@ -73,6 +81,8 @@ void SoftSourceSyntacticConstraintsFeature::SetParameter(const std::string& key,
|
||||
m_useLogprobs = Scan<bool>(value);
|
||||
} else if (key == "sparse") {
|
||||
m_useSparse = Scan<bool>(value);
|
||||
} else if (key == "sparseLabelPairs") {
|
||||
m_useSparseLabelPairs = Scan<bool>(value);
|
||||
} else {
|
||||
StatelessFeatureFunction::SetParameter(key, value);
|
||||
}
|
||||
@ -240,7 +250,8 @@ void SoftSourceSyntacticConstraintsFeature::LoadTargetSourceLeftHandSideJointCou
|
||||
sourceVector->at(foundSourceLabelIndex->second) = std::pair<float,float>(count,count);
|
||||
std::pair< boost::unordered_map<const Factor*, std::vector< std::pair<float,float> >* >::iterator, bool > insertedJointCount =
|
||||
m_labelPairProbabilities.insert( std::pair<const Factor*, std::vector< std::pair<float,float> >* >(targetLabelFactor,sourceVector) );
|
||||
assert(insertedJointCount.second);
|
||||
UTIL_THROW_IF2(!insertedJointCount.second, GetScoreProducerDescription()
|
||||
<< ": Reading target/source label joint counts from file " << m_targetSourceLHSJointCountFile << " failed.");
|
||||
}
|
||||
}
|
||||
|
||||
@ -305,8 +316,6 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
|
||||
bool treeInputMismatchLHSBinary = true;
|
||||
size_t treeInputMismatchRHSCount = 0;
|
||||
bool hasCompleteTreeInputMatch = false;
|
||||
float t2sLabelsProb = 1;
|
||||
float s2tLabelsProb = 1;
|
||||
float ruleLabelledProbability = 0.0;
|
||||
float treeInputMatchProbRHS = 0.0;
|
||||
float treeInputMatchProbLHS = 0.0;
|
||||
@ -328,14 +337,11 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
|
||||
boost::unordered_set<size_t> treeInputLabelsLHS;
|
||||
|
||||
// get index map for underlying hypotheses
|
||||
// const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
|
||||
// targetPhrase.GetAlignNonTerm().GetNonTermIndexMap();
|
||||
const WordsRange& wordsRange = inputPath.GetWordsRange();
|
||||
size_t startPos = wordsRange.GetStartPos();
|
||||
size_t endPos = wordsRange.GetEndPos();
|
||||
const Phrase *sourcePhrase = targetPhrase.GetRuleSource();
|
||||
|
||||
// std::vector<const Factor*> targetLabelsRHS;
|
||||
if (nNTs > 1) { // rule has right-hand side non-terminals, i.e. it's a hierarchical rule
|
||||
size_t nonTerminalNumber = 0;
|
||||
size_t sourceSentPos = startPos;
|
||||
@ -346,10 +352,6 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
|
||||
size_t symbolStartPos = sourceSentPos;
|
||||
size_t symbolEndPos = sourceSentPos;
|
||||
if ( word.IsNonTerminal() ) {
|
||||
// non-terminal: consult subderivation
|
||||
// size_t nonTermIndex = nonTermIndexMap[phrasePos];
|
||||
// targetLabelsRHS.push_back( word[0] );
|
||||
|
||||
// retrieve information that is required for input tree label matching (RHS)
|
||||
const ChartCellLabel &cell = *stackVec->at(nonTerminalNumber);
|
||||
const WordsRange& prevWordsRange = cell.GetCoverage();
|
||||
@ -477,25 +479,6 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// if ( hasCompleteTreeInputMatch ) {
|
||||
//
|
||||
// std::pair<float,float> probPair = GetLabelPairProbabilities( targetLHS, sourceLabelsLHSIt->first);
|
||||
// t2sLabelsProb = probPair.first;
|
||||
// s2tLabelsProb = probPair.second;
|
||||
// nonTerminalNumber=0;
|
||||
// for (std::list<size_t>::const_iterator sourceLabelsRHSIt = sourceLabelsRHS.begin();
|
||||
// sourceLabelsRHSIt != sourceLabelsRHS.end(); ++sourceLabelsRHSIt, ++nonTerminalNumber) {
|
||||
// probPair = GetLabelPairProbabilities( targetLabelsRHS[nonTerminalNumber], *sourceLabelsRHSIt );
|
||||
// t2sLabelsProb += probPair.first;
|
||||
// s2tLabelsProb += probPair.second;
|
||||
// }
|
||||
// t2sLabelsProb /= nNTs;
|
||||
// s2tLabelsProb /= nNTs;
|
||||
// assert(t2sLabelsProb != 0);
|
||||
// assert(s2tLabelsProb != 0);
|
||||
// }
|
||||
|
||||
}
|
||||
|
||||
// normalization
|
||||
@ -546,26 +529,62 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
|
||||
|
||||
// LHS
|
||||
|
||||
if ( m_useSparse ) {
|
||||
float score_LHS_0 = (float)1/treeInputLabelsLHS.size();
|
||||
for (boost::unordered_set<size_t>::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin();
|
||||
treeInputLabelsLHSIt != treeInputLabelsLHS.end(); ++treeInputLabelsLHSIt) {
|
||||
|
||||
float score_LHS_0 = (float)1/treeInputLabelsLHS.size();
|
||||
for (boost::unordered_set<size_t>::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin();
|
||||
treeInputLabelsLHSIt != treeInputLabelsLHS.end(); ++treeInputLabelsLHSIt) {
|
||||
if ( !m_useCoreSourceLabels || m_coreSourceLabels.find(*treeInputLabelsLHSIt) != m_coreSourceLabels.end() ) {
|
||||
|
||||
if ( !m_useCoreSourceLabels || m_coreSourceLabels.find(*treeInputLabelsLHSIt) != m_coreSourceLabels.end() ) {
|
||||
|
||||
if (sparseScoredTreeInputLabelsLHS.find(*treeInputLabelsLHSIt) == sparseScoredTreeInputLabelsLHS.end()) {
|
||||
// score sparse features: LHS mismatch
|
||||
scoreBreakdown.PlusEquals(this,
|
||||
m_sourceLabelsByIndex_LHS_0[*treeInputLabelsLHSIt],
|
||||
score_LHS_0);
|
||||
}
|
||||
if (sparseScoredTreeInputLabelsLHS.find(*treeInputLabelsLHSIt) == sparseScoredTreeInputLabelsLHS.end()) {
|
||||
// score sparse features: LHS mismatch
|
||||
scoreBreakdown.PlusEquals(this,
|
||||
m_sourceLabelsByIndex_LHS_0[*treeInputLabelsLHSIt],
|
||||
score_LHS_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if ( m_useSparseLabelPairs ) {
|
||||
|
||||
// left-hand side label pairs (target NT, source NT)
|
||||
float t2sLabelsScore = 0.0;
|
||||
float s2tLabelsScore = 0.0;
|
||||
for (boost::unordered_set<size_t>::const_iterator treeInputLabelsLHSIt = treeInputLabelsLHS.begin();
|
||||
treeInputLabelsLHSIt != treeInputLabelsLHS.end(); ++treeInputLabelsLHSIt) {
|
||||
|
||||
scoreBreakdown.PlusEquals(this,
|
||||
"LHSPAIR_" + targetLHS->GetString().as_string() + "_" + m_sourceLabelsByIndex[*treeInputLabelsLHSIt],
|
||||
(float)1/treeInputLabelsLHS.size());
|
||||
|
||||
if (!m_targetSourceLHSJointCountFile.empty()) {
|
||||
std::pair<float,float> probPair = GetLabelPairProbabilities( targetLHS, *treeInputLabelsLHSIt);
|
||||
t2sLabelsScore += probPair.first;
|
||||
s2tLabelsScore += probPair.second;
|
||||
}
|
||||
}
|
||||
if ( treeInputLabelsLHS.size() == 0 ) {
|
||||
scoreBreakdown.PlusEquals(this,
|
||||
"LHSPAIR_" + targetLHS->GetString().as_string() + "_" + outputDefaultNonTerminal[0]->GetString().as_string(),
|
||||
1);
|
||||
if (!m_targetSourceLHSJointCountFile.empty()) {
|
||||
t2sLabelsScore = TransformScore(m_floor);
|
||||
s2tLabelsScore = TransformScore(m_floor);
|
||||
}
|
||||
} else {
|
||||
if (!m_targetSourceLHSJointCountFile.empty()) {
|
||||
float norm = TransformScore(treeInputLabelsLHS.size());
|
||||
t2sLabelsScore = TransformScore(t2sLabelsScore) - norm;
|
||||
s2tLabelsScore = TransformScore(s2tLabelsScore) - norm;
|
||||
}
|
||||
}
|
||||
if (!m_targetSourceLHSJointCountFile.empty()) {
|
||||
scoreBreakdown.PlusEquals(this, "LHST2S", t2sLabelsScore);
|
||||
scoreBreakdown.PlusEquals(this, "LHSS2T", s2tLabelsScore);
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
// abort with error message if the phrase does not translate an unknown word
|
||||
@ -589,10 +608,6 @@ void SoftSourceSyntacticConstraintsFeature::EvaluateWithSourceContext(const Inpu
|
||||
newScores[1] = treeInputMismatchLHSBinary;
|
||||
newScores[2] = treeInputMismatchRHSCount;
|
||||
|
||||
// newScores[3] = hasCompleteTreeInputMatch ? TransformScore(ruleLabelledProbability) : 0;
|
||||
// newScores[4] = hasCompleteTreeInputMatch ? TransformScore(t2sLabelsProb) : 0;
|
||||
// newScores[5] = hasCompleteTreeInputMatch ? TransformScore(s2tLabelsProb) : 0;
|
||||
|
||||
if ( m_useLogprobs ) {
|
||||
if ( ruleLabelledProbability != 0 ) {
|
||||
ruleLabelledProbability = TransformScore(ruleLabelledProbability);
|
||||
@ -617,7 +632,7 @@ std::pair<float,float> SoftSourceSyntacticConstraintsFeature::GetLabelPairProbab
|
||||
boost::unordered_map<const Factor*, std::vector< std::pair<float,float> >* >::const_iterator found =
|
||||
m_labelPairProbabilities.find(target);
|
||||
if ( found == m_labelPairProbabilities.end() ) {
|
||||
return std::pair<float,float>(0,0);
|
||||
return std::pair<float,float>(m_floor,m_floor); // floor values
|
||||
}
|
||||
return found->second->at(source);
|
||||
}
|
||||
|
@ -71,7 +71,9 @@ protected:
|
||||
bool m_useCoreSourceLabels;
|
||||
bool m_useLogprobs;
|
||||
bool m_useSparse;
|
||||
bool m_useSparseLabelPairs;
|
||||
bool m_noMismatches;
|
||||
float m_floor;
|
||||
|
||||
boost::unordered_map<std::string,size_t> m_sourceLabels;
|
||||
std::vector<std::string> m_sourceLabelsByIndex;
|
||||
|
Loading…
Reference in New Issue
Block a user