From 9a38a7dba975a0ac2d287bb5ab6161182499730f Mon Sep 17 00:00:00 2001 From: Eva Hasler Date: Fri, 17 Feb 2012 15:55:59 +0000 Subject: [PATCH] add optional bias feature to GLMU --- moses/src/GlobalLexicalModelUnlimited.cpp | 33 +++++++++++++++-------- moses/src/GlobalLexicalModelUnlimited.h | 4 ++- moses/src/StaticData.cpp | 16 ++++++----- 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/moses/src/GlobalLexicalModelUnlimited.cpp b/moses/src/GlobalLexicalModelUnlimited.cpp index ce548f89c..14e9d821d 100644 --- a/moses/src/GlobalLexicalModelUnlimited.cpp +++ b/moses/src/GlobalLexicalModelUnlimited.cpp @@ -10,10 +10,12 @@ namespace Moses { GlobalLexicalModelUnlimited::GlobalLexicalModelUnlimited(const vector< FactorType >& inFactors, const vector< FactorType >& outFactors, - bool ignorePunctuation) + bool ignorePunctuation, + bool biasFeature) : StatelessFeatureFunction("glm",ScoreProducer::unlimited), m_sparseProducerWeight(1), - m_ignorePunctuation(ignorePunctuation) + m_ignorePunctuation(ignorePunctuation), + m_biasFeature(biasFeature) { std::cerr << "Creating global lexical model unlimited.. "; @@ -22,12 +24,12 @@ GlobalLexicalModelUnlimited::GlobalLexicalModelUnlimited(const vector< FactorTyp // compile a list of punctuation characters if (m_ignorePunctuation) { - cerr << "ignoring punctuation"; + cerr << "ignoring punctuation.. "; char punctuation[] = "\"'!?¿·()#_,.:;•&@‑/\\0123456789~="; for (size_t i=0; i < sizeof(punctuation)-1; ++i) m_punctuationHash[punctuation[i]] = 1; } - cerr << endl; + cerr << "done." << endl; } GlobalLexicalModelUnlimited::~GlobalLexicalModelUnlimited(){} @@ -53,13 +55,22 @@ void GlobalLexicalModelUnlimited::Evaluate(const TargetPhrase& targetPhrase, Sco string targetString = targetPhrase.GetWord(targetIndex).GetString(0); // TODO: change for other factors if (m_ignorePunctuation) { - // check if first char is punctuation - char firstChar = targetString.at(0); - CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); - if(charIterator != m_punctuationHash.end()) - continue; + // check if first char is punctuation + char firstChar = targetString.at(0); + CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); + if(charIterator != m_punctuationHash.end()) + continue; } + if (m_biasFeature) { + stringstream feature; + feature << "glm_"; + feature << targetString; + feature << "~"; + feature << "**BIAS**"; + accumulator->SparsePlusEquals(feature.str(), 1); + } + // set< const Word*, WordComparer > alreadyScored; // do not score a word twice StringHash alreadyScored; for(size_t inputIndex = 0; inputIndex < input.GetSize(); inputIndex++ ) { @@ -69,8 +80,8 @@ void GlobalLexicalModelUnlimited::Evaluate(const TargetPhrase& targetPhrase, Sco // check if first char is punctuation char firstChar = inputString.at(0); CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); - if(charIterator != m_punctuationHash.end()) - continue; + if(charIterator != m_punctuationHash.end()) + continue; } //if ( alreadyScored.find( &inputWord ) == alreadyScored.end() ) { diff --git a/moses/src/GlobalLexicalModelUnlimited.h b/moses/src/GlobalLexicalModelUnlimited.h index 19476c679..e867dc116 100644 --- a/moses/src/GlobalLexicalModelUnlimited.h +++ b/moses/src/GlobalLexicalModelUnlimited.h @@ -53,6 +53,7 @@ private: CharHash m_punctuationHash; bool m_ignorePunctuation; + bool m_biasFeature; std::vector< FactorType > m_inputFactors; std::vector< FactorType > m_outputFactors; @@ -65,7 +66,8 @@ private: public: GlobalLexicalModelUnlimited(const std::vector< FactorType >& inFactors, const std::vector< FactorType >& outFactors, - bool ignorePunctuation); + bool ignorePunctuation, + bool biasFeature); virtual ~GlobalLexicalModelUnlimited(); diff --git a/moses/src/StaticData.cpp b/moses/src/StaticData.cpp index af5861b22..6f6636d2f 100644 --- a/moses/src/StaticData.cpp +++ b/moses/src/StaticData.cpp @@ -973,13 +973,17 @@ bool StaticData::LoadGlobalLexicalModelUnlimited() for (size_t i = 0; i < weight.size(); i++ ) { bool ignorePunctuation = false; + bool biasFeature = false; vector< string > factors; - vector< string > factors_punctuation = Tokenize(modelSpec[i]," "); + vector< string > factors_punct_bias = Tokenize(modelSpec[i]," "); - // read optional punctuation specification - if (factors_punctuation.size() > 0) { - factors = Tokenize(factors_punctuation[0],"-"); - ignorePunctuation = Scan(factors_punctuation[1]); + // read optional punctuation and bias specifications + if (factors_punct_bias.size() > 0) { + factors = Tokenize(factors_punct_bias[0],"-"); + if (factors_punct_bias.size() >= 2) + ignorePunctuation = Scan(factors_punct_bias[1]); + if (factors_punct_bias.size() >= 3) + biasFeature = Scan(factors_punct_bias[2]); } else factors = Tokenize(modelSpec[i],"-"); @@ -990,7 +994,7 @@ bool StaticData::LoadGlobalLexicalModelUnlimited() } const vector inputFactors = Tokenize(factors[0],","); const vector outputFactors = Tokenize(factors[1],","); - m_globalLexicalModelsUnlimited.push_back(new GlobalLexicalModelUnlimited(inputFactors, outputFactors, ignorePunctuation)); + m_globalLexicalModelsUnlimited.push_back(new GlobalLexicalModelUnlimited(inputFactors, outputFactors, ignorePunctuation, biasFeature)); m_globalLexicalModelsUnlimited[i]->SetSparseProducerWeight(weight[i]); } return true;