Change context behaviour of the word translation model; simplify the global lexical model (GLM)

This commit is contained in:
Eva 2012-03-07 14:04:25 +00:00
parent 97b8616513
commit 19ea65a4d4
4 changed files with 88 additions and 192 deletions

View File

@ -118,8 +118,6 @@ FFState* GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, const
bool contextExists;
if (!m_unrestricted)
contextExists = m_vocabSource.find( contextString ) != m_vocabSource.end();
if (contextIndex == sourceIndex+1)
contextExists = true; // always add adjacent context words
if (m_unrestricted || contextExists) {
stringstream feature;
@ -156,8 +154,6 @@ FFState* GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, const
bool sourceTriggerExists = false;
if (!m_unrestricted)
sourceTriggerExists = m_vocabSource.find( sourceTrigger ) != m_vocabSource.end();
if (contextIndex == sourceIndex-1)
sourceTriggerExists = true; // always add adjacent context words
if (m_unrestricted || sourceTriggerExists)
AddFeature(accumulator, alreadyScored, sourceTrigger, sourceString,
@ -182,8 +178,6 @@ FFState* GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, const
bool targetTriggerExists = false;
if (!m_unrestricted)
targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();
if (globalContextIndex == targetIndex-1)
targetTriggerExists = true; // always add adjacent context words
if (m_unrestricted || targetTriggerExists)
AddFeature(accumulator, alreadyScored, sourceContext, sourceString,
@ -213,8 +207,6 @@ FFState* GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, const
bool targetTriggerExists = false;
if (!m_unrestricted)
targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();
if (globalContextIndex == globalTargetIndex-1)
targetTriggerExists = true; // always add adjacent context words
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
AddFeature(accumulator, alreadyScored, sourceTrigger, sourceString,
@ -230,8 +222,6 @@ FFState* GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, const
bool sourceTriggerExists = false;
if (!m_unrestricted)
sourceTriggerExists = m_vocabSource.find( sourceTrigger ) != m_vocabSource.end();
if (contextIndex == sourceIndex-1)
sourceTriggerExists = true; // always add adjacent context words
if (globalTargetIndex == 0) {
string targetTrigger = "<s>";
@ -248,8 +238,6 @@ FFState* GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, const
bool targetTriggerExists = false;
if (!m_unrestricted)
targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();
if (globalContextIndex == globalTargetIndex-1)
targetTriggerExists = true; // always add adjacent context words
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
AddFeature(accumulator, alreadyScored, sourceTrigger, sourceString,

View File

@ -1747,8 +1747,9 @@ bool StaticData::LoadWordTranslationFeature()
}
vector<string> tokens = Tokenize(parameters[0]);
if (tokens.size() != 2 && tokens.size() != 3 && tokens.size() != 5) {
UserMessage::Add("Format of word translation feature parameter is: --word-translation-feature <factor-src> <factor-tgt> [context-type] [filename-src filename-tgt]");
if (tokens.size() != 1 && tokens.size() != 4 && tokens.size() != 6) {
UserMessage::Add("Format of word translation feature parameter is: --word-translation-feature <factor-src>-<factor-tgt> "
"[simple source-trigger target-trigger] [filename-src filename-tgt]");
return false;
}
@ -1758,18 +1759,25 @@ bool StaticData::LoadWordTranslationFeature()
}
// set factor
FactorType factorIdSource = Scan<size_t>(tokens[0]);
FactorType factorIdTarget = Scan<size_t>(tokens[1]);
size_t context = 0;
if (tokens.size() >= 3)
context = Scan<size_t>(tokens[2]);
vector <string> factors = Tokenize(tokens[0],"-");
FactorType factorIdSource = Scan<size_t>(factors[0]);
FactorType factorIdTarget = Scan<size_t>(factors[1]);
bool simple = 1;
bool sourceTrigger = 0;
bool targetTrigger = 0;
if (tokens.size() >= 4) {
simple = Scan<size_t>(tokens[1]);
sourceTrigger = Scan<size_t>(tokens[2]);
targetTrigger = Scan<size_t>(tokens[3]);
}
m_wordTranslationFeature = new WordTranslationFeature(factorIdSource,factorIdTarget, context);
m_wordTranslationFeature = new WordTranslationFeature(factorIdSource, factorIdTarget, simple,
sourceTrigger, targetTrigger);
// load word list for restricted feature set
if (tokens.size() == 5) {
string filenameSource = tokens[3];
string filenameTarget = tokens[4];
if (tokens.size() == 6) {
string filenameSource = tokens[5];
string filenameTarget = tokens[6];
cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl;
if (!m_wordTranslationFeature->Load(filenameSource, filenameTarget)) {
UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource + " and " + filenameTarget);

View File

@ -69,11 +69,16 @@ FFState* WordTranslationFeature::Evaluate(const Hypothesis& cur_hypo, const FFSt
bool targetExists = m_vocabTarget.find( targetWord ) != m_vocabTarget.end();
// emit features only when both words are in the restricted vocabularies (or the model is unrestricted)
if (m_unrestricted || (sourceExists && targetExists)) {
if (m_simple) {
// construct feature name
stringstream featureName;
featureName << ((sourceExists||m_unrestricted) ? sourceWord : "OTHER");
featureName << "~";
featureName << ((targetExists||m_unrestricted) ? targetWord : "OTHER");
accumulator->PlusEquals(this,featureName.str(),1);
}
if (m_sourceContext) {
// TODO: need to determine position of the source phrase in the global input!!
cerr << "not implemented!" << endl;
exit(0);
size_t globalSourceIndex = input.GetSize() - sourcePhrase.GetSize() + sourceIndex;
size_t globalSourceIndex = cur_hypo.GetCurrSourceWordsRange().GetStartPos() + sourceIndex;
if (globalSourceIndex == 0) {
// add <s> trigger feature for source
stringstream feature;
@ -83,172 +88,76 @@ FFState* WordTranslationFeature::Evaluate(const Hypothesis& cur_hypo, const FFSt
feature << "<s>,";
feature << sourceWord;
accumulator->SparsePlusEquals(feature.str(), 1);
cerr << "feature: " << feature.str() << endl;
}
// add source words to the right of current source word as context
for(size_t contextIndex = globalSourceIndex+1; contextIndex < input.GetSize(); contextIndex++ ) {
// range over source words to get context
for(size_t contextIndex = 0; contextIndex < input.GetSize(); contextIndex++ ) {
if (contextIndex == globalSourceIndex) continue;
string sourceTrigger = input.GetWord(contextIndex).GetFactor(m_factorTypeSource)->GetString();
bool sourceTriggerExists = false;
if (!m_unrestricted)
sourceTriggerExists = m_vocabSource.find( sourceTrigger ) != m_vocabSource.end();
if (contextIndex == globalSourceIndex+1)
sourceTriggerExists = true; // always add adjacent context words
if (m_unrestricted || sourceTriggerExists) {
stringstream feature;
feature << "wt_";
feature << targetWord;
feature << "~";
feature << sourceWord;
if (contextIndex < globalSourceIndex) {
feature << sourceTrigger;
feature << ",";
feature << sourceWord;
}
else {
feature << sourceWord;
feature << ",";
feature << sourceTrigger;
}
accumulator->SparsePlusEquals(feature.str(), 1);
}
}
}
if (m_targetContext) {
size_t globalTargetIndex = cur_hypo.GetCurrTargetWordsRange().GetStartPos() + targetIndex;
cerr << "\n" << sourceWord << "-" << targetWord << endl;
cerr << "hypo size " << cur_hypo.GetSize() << endl;
cerr << "global target index: " << globalTargetIndex << endl;
if (globalTargetIndex == 0) {
// add <s> trigger feature for target
stringstream feature;
feature << "wt_";
feature << "<s>,";
feature << targetWord;
feature << "~";
feature << sourceWord;
accumulator->SparsePlusEquals(feature.str(), 1);
cerr << "feature: " << feature.str() << endl;
}
// range over target words (up to current position) to get context
for(size_t contextIndex = 0; contextIndex < globalTargetIndex; contextIndex++ ) {
string targetTrigger = cur_hypo.GetWord(contextIndex).GetFactor(m_factorTypeTarget)->GetString();
bool targetTriggerExists = false;
if (!m_unrestricted)
targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();
if (m_unrestricted || targetTriggerExists) {
stringstream feature;
feature << "wt_";
feature << targetTrigger;
feature << ",";
feature << sourceTrigger;
feature << targetWord;
feature << "~";
feature << sourceWord;
accumulator->SparsePlusEquals(feature.str(), 1);
cerr << "feature: " << feature.str() << endl;
}
}
}
else if (m_biphrase) {
// allow additional discont. triggers on one of the sides, bigram on the other side
int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex;
// TODO: need to determine position of the source phrase in the global input!!
cerr << "not implemented!" << endl;
exit(0);
int globalSourceIndex = input.GetSize() - sourcePhrase.GetSize() + sourceIndex;
// 1) source-target pair, trigger source word (can be discont.) and adjacent target word (bigram)
string targetContext;
if (globalTargetIndex > 0)
targetContext = cur_hypo.GetWord(globalTargetIndex-1).GetFactor(m_factorTypeTarget)->GetString();
else
targetContext = "<s>";
if (globalSourceIndex == 0) {
string sourceTrigger = "<s>";
AddFeature(accumulator, sourceTrigger, sourceWord,
targetContext, targetWord);
}
else
for(int contextIndex = globalSourceIndex-1; contextIndex >= 0; contextIndex-- ) {
string sourceTrigger = input.GetWord(contextIndex).GetFactor(m_factorTypeSource)->GetString();
bool sourceTriggerExists = false;
if (!m_unrestricted)
sourceTriggerExists = m_vocabSource.find( sourceTrigger ) != m_vocabSource.end();
if (contextIndex == globalSourceIndex-1)
sourceTriggerExists = true; // always add adjacent context words
if (m_unrestricted || sourceTriggerExists)
AddFeature(accumulator, sourceTrigger, sourceWord,
targetContext, targetWord);
}
// 2) source-target pair, adjacent source word (bigram) and trigger target word (can be discont.)
string sourceContext;
if (globalSourceIndex-1 >= 0)
sourceContext = input.GetWord(globalSourceIndex-1).GetFactor(m_factorTypeSource)->GetString();
else
sourceContext = "<s>";
if (globalTargetIndex == 0) {
string targetTrigger = "<s>";
AddFeature(accumulator, sourceContext, sourceWord,
targetTrigger, targetWord);
}
else
for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
string targetTrigger = cur_hypo.GetWord(globalContextIndex).GetFactor(m_factorTypeTarget)->GetString();
bool targetTriggerExists = false;
if (!m_unrestricted)
targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();
if (globalContextIndex == targetIndex-1)
targetTriggerExists = true; // always add adjacent context words
if (m_unrestricted || targetTriggerExists)
AddFeature(accumulator, sourceContext, sourceWord,
targetTrigger, targetWord);
}
}
else if (m_bitrigger) {
// allow additional discont. triggers on both sides
int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex;
// TODO: need to determine position of the source phrase in the global input!!
cerr << "not implemented!" << endl;
exit(0);
int globalSourceIndex =input.GetSize() - sourcePhrase.GetSize() + sourceIndex;
if (globalSourceIndex == 0) {
string sourceTrigger = "<s>";
bool sourceTriggerExists = true;
if (globalTargetIndex == 0) {
string targetTrigger = "<s>";
bool targetTriggerExists = true;
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
AddFeature(accumulator, sourceTrigger, sourceWord, targetTrigger, targetWord);
}
else {
// iterate backwards over target
for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
string targetTrigger = cur_hypo.GetWord(globalContextIndex).GetFactor(m_factorTypeTarget)->GetString();
bool targetTriggerExists = false;
if (!m_unrestricted)
targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();
if (globalContextIndex == globalTargetIndex-1)
targetTriggerExists = true; // always add adjacent context words
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
AddFeature(accumulator, sourceTrigger, sourceWord, targetTrigger, targetWord);
}
}
}
// iterate over both source and target
else {
// iterate backwards over source
for(int contextIndex = globalSourceIndex-1; contextIndex >= 0; contextIndex-- ) {
string sourceTrigger = input.GetWord(contextIndex).GetFactor(m_factorTypeSource)->GetString();
bool sourceTriggerExists = false;
if (!m_unrestricted)
sourceTriggerExists = m_vocabSource.find( sourceTrigger ) != m_vocabSource.end();
if (contextIndex == globalSourceIndex-1)
sourceTriggerExists = true; // always add adjacent context words
if (globalTargetIndex == 0) {
string targetTrigger = "<s>";
bool targetTriggerExists = true;
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
AddFeature(accumulator, sourceTrigger, sourceWord, targetTrigger, targetWord);
}
else {
// iterate backwards over target
for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
string targetTrigger = cur_hypo.GetWord(globalContextIndex).GetFactor(m_factorTypeTarget)->GetString();
bool targetTriggerExists = false;
if (!m_unrestricted)
targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();
if (globalContextIndex == globalTargetIndex-1)
targetTriggerExists = true; // always add adjacent context words
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
AddFeature(accumulator, sourceTrigger, sourceWord, targetTrigger, targetWord);
}
}
}
}
}
else {
// construct feature name
stringstream featureName;
featureName << ((sourceExists||m_unrestricted) ? sourceWord : "OTHER");
featureName << "~";
featureName << ((targetExists||m_unrestricted) ? targetWord : "OTHER");
accumulator->PlusEquals(this,featureName.str(),1);
}
}
}
return new DummyState();
return new DummyState();
}
void WordTranslationFeature::AddFeature(ScoreComponentCollection* accumulator, string sourceTrigger,
@ -263,6 +172,7 @@ void WordTranslationFeature::AddFeature(ScoreComponentCollection* accumulator, s
feature << ",";
feature << sourceWord;
accumulator->SparsePlusEquals(feature.str(), 1);
cerr << "feature: " << feature.str() << endl;
}
}

View File

@ -39,36 +39,26 @@ private:
FactorType m_factorTypeSource;
FactorType m_factorTypeTarget;
bool m_unrestricted;
bool m_simple;
bool m_sourceContext;
bool m_biphrase;
bool m_bitrigger;
bool m_targetContext;
public:
WordTranslationFeature(FactorType factorTypeSource = 0, FactorType factorTypeTarget = 0, size_t context = 0):
WordTranslationFeature(FactorType factorTypeSource, FactorType factorTypeTarget,
bool simple, bool sourceContext, bool targetContext):
// StatelessFeatureFunction("wt", ScoreProducer::unlimited),
StatefulFeatureFunction("wt", ScoreProducer::unlimited),
m_factorTypeSource(factorTypeSource),
m_factorTypeTarget(factorTypeTarget),
m_simple(simple),
m_sourceContext(sourceContext),
m_targetContext(targetContext),
m_unrestricted(true)
{
std::cerr << "Creating word translation feature.. ";
m_sourceContext = false;
m_biphrase = false;
m_bitrigger = false;
switch(context) {
case 1:
m_sourceContext = true;
std::cerr << "using source context.. ";
break;
case 2:
m_biphrase = true;
std::cerr << "using biphrases.. ";
break;
case 3:
m_bitrigger = true;
std::cerr << "using bitriggers.. ";
break;
}
if (m_simple == 1) std::cerr << "using simple word translations.. ";
if (m_sourceContext == 1) std::cerr << "using source context.. ";
if (m_targetContext == 1) std::cerr << "using target context.. ";
std::cerr << "done." << std::endl;
}