From bbf8a615f2db1a3eebdc27e7734bab124ffb2728 Mon Sep 17 00:00:00 2001 From: Matthias Huck Date: Thu, 17 Mar 2016 16:10:49 +0000 Subject: [PATCH] Sparse distortion features in the manner of: Spence Green, Michel Galley, Christopher D. Manning. 2010. Improved Models of Distortion Cost for Statistical Machine Translation. In NAACL 2010. --- moses/FF/DistortionScoreProducer.cpp | 217 +++++++++++++++++++++++++-- moses/FF/DistortionScoreProducer.h | 19 ++- 2 files changed, 219 insertions(+), 17 deletions(-) diff --git a/moses/FF/DistortionScoreProducer.cpp b/moses/FF/DistortionScoreProducer.cpp index 0cc9dfe10..73f0821e0 100644 --- a/moses/FF/DistortionScoreProducer.cpp +++ b/moses/FF/DistortionScoreProducer.cpp @@ -1,26 +1,31 @@ #include "DistortionScoreProducer.h" #include "FFState.h" +#include "moses/InputPath.h" #include "moses/Range.h" #include "moses/StaticData.h" #include "moses/Hypothesis.h" #include "moses/Manager.h" +#include "moses/FactorCollection.h" +#include + using namespace std; namespace Moses { -struct DistortionState_traditional : public FFState { +struct DistortionState : public FFState { Range range; int first_gap; - DistortionState_traditional(const Range& wr, int fg) : range(wr), first_gap(fg) {} + bool inSubordinateConjunction; + DistortionState(const Range& wr, int fg, bool subord=false) : range(wr), first_gap(fg), inSubordinateConjunction(subord) {} size_t hash() const { return range.GetEndPos(); } virtual bool operator==(const FFState& other) const { - const DistortionState_traditional& o = - static_cast(other); - return range.GetEndPos() == o.range.GetEndPos(); + const DistortionState& o = + static_cast(other); + return ( (range.GetEndPos() == o.range.GetEndPos()) && (inSubordinateConjunction == o.inSubordinateConjunction) ); } }; @@ -29,11 +34,36 @@ std::vector DistortionScoreProducer::s_staticCol DistortionScoreProducer::DistortionScoreProducer(const std::string &line) : StatefulFeatureFunction(1, line) + , m_useSparse(false) + , m_sparseDistance(false) + , m_sparseSubordinate(false) { s_staticColl.push_back(this); ReadParameters(); } +void DistortionScoreProducer::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "sparse") { + m_useSparse = Scan(value); + } else if (key == "sparse-distance") { + m_sparseDistance = Scan(value); + } else if (key == "sparse-input-factor") { + m_sparseFactorTypeSource = Scan(value); + } else if (key == "sparse-output-factor") { + m_sparseFactorTypeTarget = Scan(value); + } else if (key == "sparse-subordinate") { + std::string subordinateConjunctionTag = Scan(value); + FactorCollection &factorCollection = FactorCollection::Instance(); + m_subordinateConjunctionTagFactor = factorCollection.AddFactor(subordinateConjunctionTag,false); + m_sparseSubordinate = true; + } else if (key == "sparse-subordinate-output-factor") { + m_sparseFactorTypeTargetSubordinate = Scan(value); + } else { + StatefulFeatureFunction::SetParameter(key, value); + } +} + const FFState* DistortionScoreProducer::EmptyHypothesisState(const InputType &input) const { // fake previous translated phrase start and end @@ -44,7 +74,7 @@ const FFState* DistortionScoreProducer::EmptyHypothesisState(const InputType &in start = 0; end = input.m_frontSpanCoveredLength -1; } - return new DistortionState_traditional( + return new DistortionState( Range(start, end), NOT_FOUND); } @@ -101,17 +131,184 @@ FFState* DistortionScoreProducer::EvaluateWhenApplied( const FFState* prev_state, ScoreComponentCollection* out) const { - const DistortionState_traditional* prev = static_cast(prev_state); + const DistortionState* prev = static_cast(prev_state); + bool subordinateConjunction = prev->inSubordinateConjunction; + + if (m_useSparse) { + int jumpFromPos = prev->range.GetEndPos()+1; + int jumpToPos = hypo.GetCurrSourceWordsRange().GetStartPos(); + size_t distance = std::abs( jumpFromPos - jumpToPos ); + + const Sentence& sentence = static_cast(hypo.GetInput()); + + StringPiece jumpFromSourceFactorPrev; + StringPiece jumpFromSourceFactor; + StringPiece jumpToSourceFactor; + if (jumpFromPos < (int)sentence.GetSize()) { + jumpFromSourceFactor = sentence.GetWord(jumpFromPos).GetFactor(m_sparseFactorTypeSource)->GetString(); + } else { + jumpFromSourceFactor = ""; + } + if (jumpFromPos > 0) { + jumpFromSourceFactorPrev = sentence.GetWord(jumpFromPos-1).GetFactor(m_sparseFactorTypeSource)->GetString(); + } else { + jumpFromSourceFactorPrev = ""; + } + jumpToSourceFactor = sentence.GetWord(jumpToPos).GetFactor(m_sparseFactorTypeSource)->GetString(); + + const TargetPhrase& currTargetPhrase = hypo.GetCurrTargetPhrase(); + StringPiece jumpToTargetFactor = currTargetPhrase.GetWord(0).GetFactor(m_sparseFactorTypeTarget)->GetString(); + + util::StringStream featureName; + + // source factor (start position) + featureName = util::StringStream(); + featureName << m_description << "_"; + if ( jumpToPos > jumpFromPos ) { + featureName << "R"; + } else if ( jumpToPos < jumpFromPos ) { + featureName << "L"; + } else { + featureName << "M"; + } + if (m_sparseDistance) { + featureName << distance; + } + featureName << "_SFS_" << jumpFromSourceFactor; + if (m_sparseSubordinate && subordinateConjunction) { + featureName << "_SUBORD"; + } + out->SparsePlusEquals(featureName.str(), 1); + + // source factor (start position minus 1) + featureName = util::StringStream(); + featureName << m_description << "_"; + if ( jumpToPos > jumpFromPos ) { + featureName << "R"; + } else if ( jumpToPos < jumpFromPos ) { + featureName << "L"; + } else { + featureName << "M"; + } + if (m_sparseDistance) { + featureName << distance; + } + featureName << "_SFP_" << jumpFromSourceFactorPrev; + if (m_sparseSubordinate && subordinateConjunction) { + featureName << "_SUBORD"; + } + out->SparsePlusEquals(featureName.str(), 1); + + // source factor (end position) + featureName = util::StringStream(); + featureName << m_description << "_"; + if ( jumpToPos > jumpFromPos ) { + featureName << "R"; + } else if ( jumpToPos < jumpFromPos ) { + featureName << "L"; + } else { + featureName << "M"; + } + if (m_sparseDistance) { + featureName << distance; + } + featureName << "_SFE_" << jumpToSourceFactor; + if (m_sparseSubordinate && subordinateConjunction) { + featureName << "_SUBORD"; + } + out->SparsePlusEquals(featureName.str(), 1); + + // target factor (end position) + featureName = util::StringStream(); + featureName << m_description << "_"; + if ( jumpToPos > jumpFromPos ) { + featureName << "R"; + } else if ( jumpToPos < jumpFromPos ) { + featureName << "L"; + } else { + featureName << "M"; + } + if (m_sparseDistance) { + featureName << distance; + } + featureName << "_TFE_" << jumpToTargetFactor; + if (m_sparseSubordinate && subordinateConjunction) { + featureName << "_SUBORD"; + } + out->SparsePlusEquals(featureName.str(), 1); + + // relative source sentence position + featureName = util::StringStream(); + featureName << m_description << "_"; + if ( jumpToPos > jumpFromPos ) { + featureName << "R"; + } else if ( jumpToPos < jumpFromPos ) { + featureName << "L"; + } else { + featureName << "M"; + } + if (m_sparseDistance) { + featureName << distance; + } + size_t relativeSourceSentencePosBin = std::floor( 5 * (float)jumpFromPos / (sentence.GetSize()+1) ); + featureName << "_P_" << relativeSourceSentencePosBin; + if (m_sparseSubordinate && subordinateConjunction) { + featureName << "_SUBORD"; + } + out->SparsePlusEquals(featureName.str(), 1); + + // source sentence length bin + featureName = util::StringStream(); + featureName << m_description << "_"; + if ( jumpToPos > jumpFromPos ) { + featureName << "R"; + } else if ( jumpToPos < jumpFromPos ) { + featureName << "L"; + } else { + featureName << "M"; + } + if (m_sparseDistance) { + featureName << distance; + } + size_t sourceSentenceLengthBin = 3; + if (sentence.GetSize() < 15) { + sourceSentenceLengthBin = 0; + } else if (sentence.GetSize() < 23) { + sourceSentenceLengthBin = 1; + } else if (sentence.GetSize() < 33) { + sourceSentenceLengthBin = 2; + } + featureName << "_SL_" << sourceSentenceLengthBin; + if (m_sparseSubordinate && subordinateConjunction) { + featureName << "_SUBORD"; + } + out->SparsePlusEquals(featureName.str(), 1); + + if (m_sparseSubordinate) { + for (size_t posT=0; posTGetString()[0] == 'V') { + subordinateConjunction = false; + } + }; + } + } + const float distortionScore = CalculateDistortionScore( hypo, prev->range, hypo.GetCurrSourceWordsRange(), prev->first_gap); out->PlusEquals(this, distortionScore); - DistortionState_traditional* res = new DistortionState_traditional( + + DistortionState* state = new DistortionState( hypo.GetCurrSourceWordsRange(), - hypo.GetWordsBitmap().GetFirstGapPos()); - return res; + hypo.GetWordsBitmap().GetFirstGapPos(), + subordinateConjunction); + + return state; } diff --git a/moses/FF/DistortionScoreProducer.h b/moses/FF/DistortionScoreProducer.h index cfe0dc005..d59214df7 100644 --- a/moses/FF/DistortionScoreProducer.h +++ b/moses/FF/DistortionScoreProducer.h @@ -1,16 +1,11 @@ #pragma once -#include #include #include "StatefulFeatureFunction.h" +#include "moses/Range.h" namespace Moses { -class FFState; -class ScoreComponentCollection; -class Hypothesis; -class ChartHypothesis; -class Range; /** Calculates Distortion scores */ @@ -19,6 +14,14 @@ class DistortionScoreProducer : public StatefulFeatureFunction protected: static std::vector s_staticColl; + FactorType m_sparseFactorTypeSource; + FactorType m_sparseFactorTypeTarget; + bool m_useSparse; + bool m_sparseDistance; + bool m_sparseSubordinate; + FactorType m_sparseFactorTypeTargetSubordinate; + const Factor* m_subordinateConjunctionTagFactor; + public: static const std::vector& GetDistortionFeatureFunctions() { return s_staticColl; @@ -26,6 +29,8 @@ public: DistortionScoreProducer(const std::string &line); + void SetParameter(const std::string& key, const std::string& value); + bool IsUseable(const FactorMask &mask) const { return true; } @@ -44,7 +49,7 @@ public: const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection*) const { - throw std::logic_error("DistortionScoreProducer not supported in chart decoder, yet"); + UTIL_THROW(util::Exception, "DIstortion not implemented in chart decoder"); } };