Sparse distortion features

in the manner of:

Spence Green, Michel Galley, Christopher D. Manning. 2010.  Improved
Models of Distortion Cost for Statistical Machine Translation. In NAACL
2010.
This commit is contained in:
Matthias Huck 2016-03-17 16:10:49 +00:00
parent 1ac4ca5735
commit bbf8a615f2
2 changed files with 219 additions and 17 deletions

View File

@ -1,26 +1,31 @@
#include "DistortionScoreProducer.h"
#include "FFState.h"
#include "moses/InputPath.h"
#include "moses/Range.h"
#include "moses/StaticData.h"
#include "moses/Hypothesis.h"
#include "moses/Manager.h"
#include "moses/FactorCollection.h"
#include <cmath>
using namespace std;
namespace Moses
{
// State object the distortion feature attaches to each hypothesis.
// NOTE(review): this view interleaves the previous revision of the struct
// (DistortionState_traditional) with its replacement (DistortionState); the
// replacement additionally tracks inSubordinateConjunction.
struct DistortionState_traditional : public FFState {
struct DistortionState : public FFState {
// Source word range handed to the constructor (the current source words
// range of the hypothesis when built in EvaluateWhenApplied).
Range range;
// First-gap position handed to the constructor; later forwarded to
// CalculateDistortionScore.
int first_gap;
DistortionState_traditional(const Range& wr, int fg) : range(wr), first_gap(fg) {}
// Flag carried from hypothesis to hypothesis for the "_SUBORD" sparse
// feature suffix; defaults to false when not supplied.
bool inSubordinateConjunction;
DistortionState(const Range& wr, int fg, bool subord=false) : range(wr), first_gap(fg), inSubordinateConjunction(subord) {}
// Hash over the range end position only. States that differ solely in
// inSubordinateConjunction share a hash value; operator== tells them apart.
size_t hash() const {
return range.GetEndPos();
}
// Equality used for comparing states. The static_cast is unchecked, so
// `other` is assumed to be of the same state type.
virtual bool operator==(const FFState& other) const {
const DistortionState_traditional& o =
static_cast<const DistortionState_traditional&>(other);
return range.GetEndPos() == o.range.GetEndPos();
const DistortionState& o =
static_cast<const DistortionState&>(other);
// Two states are equal when both the range end position and the
// subordinate-conjunction flag agree.
return ( (range.GetEndPos() == o.range.GetEndPos()) && (inSubordinateConjunction == o.inSubordinateConjunction) );
}
};
@ -29,11 +34,36 @@ std::vector<const DistortionScoreProducer*> DistortionScoreProducer::s_staticCol
/**
 * Builds the distortion feature function from its configuration line.
 *
 * Every member gets a defined default before ReadParameters() runs, so that
 * options missing from the configuration never leave a factor type or the
 * subordinate-conjunction tag factor indeterminate (they are read later when
 * sparse features are enabled). Initialisers follow the member declaration
 * order of the class.
 *
 * @param line feature configuration line, forwarded to StatefulFeatureFunction
 *             (constructed with one dense score component).
 */
DistortionScoreProducer::DistortionScoreProducer(const std::string &line)
  : StatefulFeatureFunction(1, line)
  , m_sparseFactorTypeSource(0)   // NOTE(review): 0 assumed to be the surface factor - confirm
  , m_sparseFactorTypeTarget(0)
  , m_useSparse(false)
  , m_sparseDistance(false)
  , m_sparseSubordinate(false)
  , m_sparseFactorTypeTargetSubordinate(0)
  , m_subordinateConjunctionTagFactor(NULL)
{
  // Register this instance in the static collection of distortion features.
  s_staticColl.push_back(this);
  // Presumably dispatches each key/value of the line to SetParameter(), which
  // overrides the defaults set above - verify against the base class.
  ReadParameters();
}
void DistortionScoreProducer::SetParameter(const std::string& key, const std::string& value)
{
if (key == "sparse") {
m_useSparse = Scan<bool>(value);
} else if (key == "sparse-distance") {
m_sparseDistance = Scan<bool>(value);
} else if (key == "sparse-input-factor") {
m_sparseFactorTypeSource = Scan<FactorType>(value);
} else if (key == "sparse-output-factor") {
m_sparseFactorTypeTarget = Scan<FactorType>(value);
} else if (key == "sparse-subordinate") {
std::string subordinateConjunctionTag = Scan<std::string>(value);
FactorCollection &factorCollection = FactorCollection::Instance();
m_subordinateConjunctionTagFactor = factorCollection.AddFactor(subordinateConjunctionTag,false);
m_sparseSubordinate = true;
} else if (key == "sparse-subordinate-output-factor") {
m_sparseFactorTypeTargetSubordinate = Scan<FactorType>(value);
} else {
StatefulFeatureFunction::SetParameter(key, value);
}
}
const FFState* DistortionScoreProducer::EmptyHypothesisState(const InputType &input) const
{
// fake previous translated phrase start and end
@ -44,7 +74,7 @@ const FFState* DistortionScoreProducer::EmptyHypothesisState(const InputType &in
start = 0;
end = input.m_frontSpanCoveredLength -1;
}
return new DistortionState_traditional(
return new DistortionState(
Range(start, end),
NOT_FOUND);
}
@ -101,17 +131,184 @@ FFState* DistortionScoreProducer::EvaluateWhenApplied(
const FFState* prev_state,
ScoreComponentCollection* out) const
{
const DistortionState_traditional* prev = static_cast<const DistortionState_traditional*>(prev_state);
const DistortionState* prev = static_cast<const DistortionState*>(prev_state);
bool subordinateConjunction = prev->inSubordinateConjunction;
if (m_useSparse) {
int jumpFromPos = prev->range.GetEndPos()+1;
int jumpToPos = hypo.GetCurrSourceWordsRange().GetStartPos();
size_t distance = std::abs( jumpFromPos - jumpToPos );
const Sentence& sentence = static_cast<const Sentence&>(hypo.GetInput());
StringPiece jumpFromSourceFactorPrev;
StringPiece jumpFromSourceFactor;
StringPiece jumpToSourceFactor;
if (jumpFromPos < (int)sentence.GetSize()) {
jumpFromSourceFactor = sentence.GetWord(jumpFromPos).GetFactor(m_sparseFactorTypeSource)->GetString();
} else {
jumpFromSourceFactor = "</s>";
}
if (jumpFromPos > 0) {
jumpFromSourceFactorPrev = sentence.GetWord(jumpFromPos-1).GetFactor(m_sparseFactorTypeSource)->GetString();
} else {
jumpFromSourceFactorPrev = "<s>";
}
jumpToSourceFactor = sentence.GetWord(jumpToPos).GetFactor(m_sparseFactorTypeSource)->GetString();
const TargetPhrase& currTargetPhrase = hypo.GetCurrTargetPhrase();
StringPiece jumpToTargetFactor = currTargetPhrase.GetWord(0).GetFactor(m_sparseFactorTypeTarget)->GetString();
util::StringStream featureName;
// source factor (start position)
featureName = util::StringStream();
featureName << m_description << "_";
if ( jumpToPos > jumpFromPos ) {
featureName << "R";
} else if ( jumpToPos < jumpFromPos ) {
featureName << "L";
} else {
featureName << "M";
}
if (m_sparseDistance) {
featureName << distance;
}
featureName << "_SFS_" << jumpFromSourceFactor;
if (m_sparseSubordinate && subordinateConjunction) {
featureName << "_SUBORD";
}
out->SparsePlusEquals(featureName.str(), 1);
// source factor (start position minus 1)
featureName = util::StringStream();
featureName << m_description << "_";
if ( jumpToPos > jumpFromPos ) {
featureName << "R";
} else if ( jumpToPos < jumpFromPos ) {
featureName << "L";
} else {
featureName << "M";
}
if (m_sparseDistance) {
featureName << distance;
}
featureName << "_SFP_" << jumpFromSourceFactorPrev;
if (m_sparseSubordinate && subordinateConjunction) {
featureName << "_SUBORD";
}
out->SparsePlusEquals(featureName.str(), 1);
// source factor (end position)
featureName = util::StringStream();
featureName << m_description << "_";
if ( jumpToPos > jumpFromPos ) {
featureName << "R";
} else if ( jumpToPos < jumpFromPos ) {
featureName << "L";
} else {
featureName << "M";
}
if (m_sparseDistance) {
featureName << distance;
}
featureName << "_SFE_" << jumpToSourceFactor;
if (m_sparseSubordinate && subordinateConjunction) {
featureName << "_SUBORD";
}
out->SparsePlusEquals(featureName.str(), 1);
// target factor (end position)
featureName = util::StringStream();
featureName << m_description << "_";
if ( jumpToPos > jumpFromPos ) {
featureName << "R";
} else if ( jumpToPos < jumpFromPos ) {
featureName << "L";
} else {
featureName << "M";
}
if (m_sparseDistance) {
featureName << distance;
}
featureName << "_TFE_" << jumpToTargetFactor;
if (m_sparseSubordinate && subordinateConjunction) {
featureName << "_SUBORD";
}
out->SparsePlusEquals(featureName.str(), 1);
// relative source sentence position
featureName = util::StringStream();
featureName << m_description << "_";
if ( jumpToPos > jumpFromPos ) {
featureName << "R";
} else if ( jumpToPos < jumpFromPos ) {
featureName << "L";
} else {
featureName << "M";
}
if (m_sparseDistance) {
featureName << distance;
}
size_t relativeSourceSentencePosBin = std::floor( 5 * (float)jumpFromPos / (sentence.GetSize()+1) );
featureName << "_P_" << relativeSourceSentencePosBin;
if (m_sparseSubordinate && subordinateConjunction) {
featureName << "_SUBORD";
}
out->SparsePlusEquals(featureName.str(), 1);
// source sentence length bin
featureName = util::StringStream();
featureName << m_description << "_";
if ( jumpToPos > jumpFromPos ) {
featureName << "R";
} else if ( jumpToPos < jumpFromPos ) {
featureName << "L";
} else {
featureName << "M";
}
if (m_sparseDistance) {
featureName << distance;
}
size_t sourceSentenceLengthBin = 3;
if (sentence.GetSize() < 15) {
sourceSentenceLengthBin = 0;
} else if (sentence.GetSize() < 23) {
sourceSentenceLengthBin = 1;
} else if (sentence.GetSize() < 33) {
sourceSentenceLengthBin = 2;
}
featureName << "_SL_" << sourceSentenceLengthBin;
if (m_sparseSubordinate && subordinateConjunction) {
featureName << "_SUBORD";
}
out->SparsePlusEquals(featureName.str(), 1);
if (m_sparseSubordinate) {
for (size_t posT=0; posT<currTargetPhrase.GetSize(); ++posT) {
const Word &wordT = currTargetPhrase.GetWord(posT);
if (wordT[m_sparseFactorTypeTargetSubordinate] == m_subordinateConjunctionTagFactor) {
subordinateConjunction = true;
} else if (wordT[m_sparseFactorTypeTargetSubordinate]->GetString()[0] == 'V') {
subordinateConjunction = false;
}
};
}
}
const float distortionScore = CalculateDistortionScore(
hypo,
prev->range,
hypo.GetCurrSourceWordsRange(),
prev->first_gap);
out->PlusEquals(this, distortionScore);
DistortionState_traditional* res = new DistortionState_traditional(
DistortionState* state = new DistortionState(
hypo.GetCurrSourceWordsRange(),
hypo.GetWordsBitmap().GetFirstGapPos());
return res;
hypo.GetWordsBitmap().GetFirstGapPos(),
subordinateConjunction);
return state;
}

View File

@ -1,16 +1,11 @@
#pragma once
#include <stdexcept>
#include <string>
#include "StatefulFeatureFunction.h"
#include "moses/Range.h"
namespace Moses
{
class FFState;
class ScoreComponentCollection;
class Hypothesis;
class ChartHypothesis;
class Range;
/** Calculates Distortion scores
*/
@ -19,6 +14,14 @@ class DistortionScoreProducer : public StatefulFeatureFunction
protected:
static std::vector<const DistortionScoreProducer*> s_staticColl;
FactorType m_sparseFactorTypeSource;
FactorType m_sparseFactorTypeTarget;
bool m_useSparse;
bool m_sparseDistance;
bool m_sparseSubordinate;
FactorType m_sparseFactorTypeTargetSubordinate;
const Factor* m_subordinateConjunctionTagFactor;
public:
static const std::vector<const DistortionScoreProducer*>& GetDistortionFeatureFunctions() {
return s_staticColl;
@ -26,6 +29,8 @@ public:
DistortionScoreProducer(const std::string &line);
void SetParameter(const std::string& key, const std::string& value);
bool IsUseable(const FactorMask &mask) const {
return true;
}
@ -44,7 +49,7 @@ public:
const ChartHypothesis& /* cur_hypo */,
int /* featureID - used to index the state in the previous hypotheses */,
ScoreComponentCollection*) const {
throw std::logic_error("DistortionScoreProducer not supported in chart decoder, yet");
UTIL_THROW(util::Exception, "DIstortion not implemented in chart decoder");
}
};