add optional bias feature to GLMU

Eva Hasler 2012-02-17 15:55:59 +00:00
parent ee2ea468df
commit 9a38a7dba9
3 changed files with 35 additions and 18 deletions
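GlobalLexicalModelUnlimited gains a boolean biasFeature option. When it is enabled, Evaluate() adds one extra sparse feature per target word, named glm_<targetWord>~**BIAS**, alongside the model's existing lexical features. The flag is threaded through the constructor and the header declaration, and StaticData::LoadGlobalLexicalModelUnlimited() reads it from an optional third space-separated field of the model specification (the second field remains the ignore-punctuation flag).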

View File

@@ -10,10 +10,12 @@ namespace Moses
 {
 GlobalLexicalModelUnlimited::GlobalLexicalModelUnlimited(const vector< FactorType >& inFactors,
     const vector< FactorType >& outFactors,
-    bool ignorePunctuation)
+    bool ignorePunctuation,
+    bool biasFeature)
   : StatelessFeatureFunction("glm",ScoreProducer::unlimited),
     m_sparseProducerWeight(1),
-    m_ignorePunctuation(ignorePunctuation)
+    m_ignorePunctuation(ignorePunctuation),
+    m_biasFeature(biasFeature)
 {
   std::cerr << "Creating global lexical model unlimited.. ";
@@ -22,12 +24,12 @@ GlobalLexicalModelUnlimited::GlobalLexicalModelUnlimited(const vector< FactorTyp
   // compile a list of punctuation characters
   if (m_ignorePunctuation) {
-    cerr << "ignoring punctuation";
+    cerr << "ignoring punctuation.. ";
     char punctuation[] = "\"'!?¿·()#_,.:;•&@/\\0123456789~=";
     for (size_t i=0; i < sizeof(punctuation)-1; ++i)
      m_punctuationHash[punctuation[i]] = 1;
   }
-  cerr << endl;
+  cerr << "done." << endl;
 }
 
 GlobalLexicalModelUnlimited::~GlobalLexicalModelUnlimited(){}
@@ -53,13 +55,22 @@ void GlobalLexicalModelUnlimited::Evaluate(const TargetPhrase& targetPhrase, Sco
     string targetString = targetPhrase.GetWord(targetIndex).GetString(0); // TODO: change for other factors
 
     if (m_ignorePunctuation) {
-    // check if first char is punctuation
-    char firstChar = targetString.at(0);
-    CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
-    if(charIterator != m_punctuationHash.end())
-    continue;
+      // check if first char is punctuation
+      char firstChar = targetString.at(0);
+      CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
+      if(charIterator != m_punctuationHash.end())
+        continue;
     }
 
+    if (m_biasFeature) {
+      stringstream feature;
+      feature << "glm_";
+      feature << targetString;
+      feature << "~";
+      feature << "**BIAS**";
+      accumulator->SparsePlusEquals(feature.str(), 1);
+    }
+
     // set< const Word*, WordComparer > alreadyScored; // do not score a word twice
     StringHash alreadyScored;
     for(size_t inputIndex = 0; inputIndex < input.GetSize(); inputIndex++ ) {
@@ -69,8 +80,8 @@ void GlobalLexicalModelUnlimited::Evaluate(const TargetPhrase& targetPhrase, Sco
         // check if first char is punctuation
         char firstChar = inputString.at(0);
         CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
-      if(charIterator != m_punctuationHash.end())
-      continue;
+        if(charIterator != m_punctuationHash.end())
+          continue;
       }
 
       //if ( alreadyScored.find( &inputWord ) == alreadyScored.end() ) {
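For illustration, here is a standalone sketch of the feature name that the new biasFeature block builds. The target word "house" is a made-up example, and the accumulator update is only described in a comment:

#include <iostream>
#include <sstream>
#include <string>

int main() {
  // made-up target word; in the decoder it comes from targetPhrase.GetWord(targetIndex)
  std::string targetString = "house";

  // same concatenation as the new block in Evaluate():
  // "glm_" + target word + "~" + "**BIAS**"
  std::stringstream feature;
  feature << "glm_";
  feature << targetString;
  feature << "~";
  feature << "**BIAS**";

  // Moses passes this name to accumulator->SparsePlusEquals(feature.str(), 1),
  // so the sparse feature fires once per target word when biasFeature is on
  std::cout << feature.str() << std::endl;  // prints glm_house~**BIAS**
  return 0;
}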

View File

@@ -53,6 +53,7 @@ private:
   CharHash m_punctuationHash;
 
   bool m_ignorePunctuation;
+  bool m_biasFeature;
 
   std::vector< FactorType > m_inputFactors;
   std::vector< FactorType > m_outputFactors;
@@ -65,7 +66,8 @@ private:
 
 public:
   GlobalLexicalModelUnlimited(const std::vector< FactorType >& inFactors,
       const std::vector< FactorType >& outFactors,
-      bool ignorePunctuation);
+      bool ignorePunctuation,
+      bool biasFeature);
 
   virtual ~GlobalLexicalModelUnlimited();
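Call sites now have to supply the extra argument. A short fragment of the updated constructor call (the header path and the factor values are placeholders; the real call, with values parsed from the model specification, is in StaticData below):

#include <vector>
#include "GlobalLexicalModelUnlimited.h"  // assumed include for the Moses class
using namespace Moses;

// placeholder factors: source factor 0 and target factor 0
std::vector<FactorType> inputFactors(1, 0);
std::vector<FactorType> outputFactors(1, 0);

// the fourth argument is the new one: enable the per-target-word bias feature
GlobalLexicalModelUnlimited* glm =
    new GlobalLexicalModelUnlimited(inputFactors, outputFactors,
                                    /*ignorePunctuation=*/true,
                                    /*biasFeature=*/true);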

View File

@@ -973,13 +973,17 @@ bool StaticData::LoadGlobalLexicalModelUnlimited()
   for (size_t i = 0; i < weight.size(); i++ ) {
 
     bool ignorePunctuation = false;
+    bool biasFeature = false;
     vector< string > factors;
-    vector< string > factors_punctuation = Tokenize(modelSpec[i]," ");
+    vector< string > factors_punct_bias = Tokenize(modelSpec[i]," ");
 
-    // read optional punctuation specification
-    if (factors_punctuation.size() > 0) {
-      factors = Tokenize(factors_punctuation[0],"-");
-      ignorePunctuation = Scan<int>(factors_punctuation[1]);
+    // read optional punctuation and bias specifications
+    if (factors_punct_bias.size() > 0) {
+      factors = Tokenize(factors_punct_bias[0],"-");
+      if (factors_punct_bias.size() >= 2)
+        ignorePunctuation = Scan<int>(factors_punct_bias[1]);
+      if (factors_punct_bias.size() >= 3)
+        biasFeature = Scan<int>(factors_punct_bias[2]);
     }
     else
       factors = Tokenize(modelSpec[i],"-");
@@ -990,7 +994,7 @@ bool StaticData::LoadGlobalLexicalModelUnlimited()
     }
     const vector<FactorType> inputFactors = Tokenize<FactorType>(factors[0],",");
     const vector<FactorType> outputFactors = Tokenize<FactorType>(factors[1],",");
-    m_globalLexicalModelsUnlimited.push_back(new GlobalLexicalModelUnlimited(inputFactors, outputFactors, ignorePunctuation));
+    m_globalLexicalModelsUnlimited.push_back(new GlobalLexicalModelUnlimited(inputFactors, outputFactors, ignorePunctuation, biasFeature));
     m_globalLexicalModelsUnlimited[i]->SetSparseProducerWeight(weight[i]);
   }
   return true;
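With this change, a GLMU model specification can carry up to three space-separated fields: the input-output factor pair, an optional ignore-punctuation flag, and an optional bias-feature flag; the new size checks also make both trailing flags genuinely optional. Below is a small standalone sketch of the parsing logic, using a made-up spec value "0-0 1 1" and std helpers in place of Moses' Tokenize/Scan:

#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// stand-in for Moses' Tokenize(): split on a single-character delimiter
static std::vector<std::string> Split(const std::string& s, char delim) {
  std::vector<std::string> out;
  std::stringstream ss(s);
  std::string tok;
  while (std::getline(ss, tok, delim))
    if (!tok.empty()) out.push_back(tok);
  return out;
}

int main() {
  // made-up spec: factors "0-0", ignorePunctuation = 1, biasFeature = 1
  std::string modelSpec = "0-0 1 1";

  bool ignorePunctuation = false;
  bool biasFeature = false;

  std::vector<std::string> fields = Split(modelSpec, ' ');
  std::vector<std::string> factors = Split(fields[0], '-');  // {"0", "0"}
  if (fields.size() >= 2)
    ignorePunctuation = std::atoi(fields[1].c_str()) != 0;   // mirrors Scan<int>(...)
  if (fields.size() >= 3)
    biasFeature = std::atoi(fields[2].c_str()) != 0;

  std::cout << "input factor(s):   " << factors[0] << "\n"
            << "output factor(s):  " << factors[1] << "\n"
            << "ignorePunctuation: " << ignorePunctuation << "\n"
            << "biasFeature:       " << biasFeature << std::endl;
  return 0;
}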