move code from LoadWordTranslationFeature() to constructor in class WordTranslationFeature

This commit is contained in:
Hieu Hoang 2013-01-02 21:26:56 +00:00
parent b71d6c9b24
commit 2f6b445bf6
5 changed files with 108 additions and 122 deletions

View File

@ -151,7 +151,6 @@ Parameter::Parameter()
AddParam("unknown-lhs", "file containing target lhs of unknown words. 1 per line: LHS prob");
AddParam("show-weights", "print feature weights and exit");
AddParam("start-translation-id", "Id of 1st input. Default = 0");
AddParam("text-type", "should be one of dev/devtest/test, used for domain adaptation features");
AddParam("output-unknowns", "Output the unknown (OOV) words to the given file, one line per sentence");
// Compact phrase table and reordering table.
@ -184,6 +183,7 @@ Parameter::Parameter()
AddParam("weight-w", "w", "DEPRECATED. DO NOT USE. weight for word penalty");
AddParam("weight-u", "u", "DEPRECATED. DO NOT USE. weight for unknown word penalty");
AddParam("weight-e", "e", "DEPRECATED. DO NOT USE. weight for word deletion");
//AddParam("text-type", "DEPRECATED. DO NOT USE. should be one of dev/devtest/test, used for domain adaptation features");
AddParam("weight-file", "wf", "feature weights file. Do *not* put weights for 'core' features in here - they go in moses.ini");

View File

@ -585,6 +585,11 @@ SetWeight(m_unknownWordPenaltyProducer, weightUnknownWord);
const vector<float> &weights = m_parameter->GetWeights(feature, featureIndex);
//SetWeights(model, weights);
}
else if (feature == "WordTranslationFeature") {
WordTranslationFeature *model = new WordTranslationFeature(line);
const vector<float> &weights = m_parameter->GetWeights(feature, featureIndex);
//SetWeights(model, weights);
}
}
@ -602,7 +607,6 @@ SetWeight(m_unknownWordPenaltyProducer, weightUnknownWord);
if (!LoadReferences()) return false;
if (!LoadDiscrimLMFeature()) return false;
if (!LoadPhrasePairFeature()) return false;
if (!LoadWordTranslationFeature()) return false;
// report individual sparse features in n-best list
if (m_parameter->GetParam("report-sparse-features").size() > 0) {
@ -1424,95 +1428,6 @@ bool StaticData::LoadPhrasePairFeature()
return true;
}
bool StaticData::LoadWordTranslationFeature()
{
const vector<string> &parameters = m_parameter->GetParam("word-translation-feature");
if (parameters.empty())
return true;
const vector<float> &weight = m_parameter->GetWeights("WordPenalty", 0);
CHECK(weight.size() == 1);
m_needAlignmentInfo = true;
for (size_t i=0; i<parameters.size(); ++i) {
vector<string> tokens = Tokenize(parameters[i]);
if (tokens.size() != 1 && !(tokens.size() >= 4 && tokens.size() <= 8)) {
UserMessage::Add("Format of word translation feature parameter is: --word-translation-feature <factor-src>-<factor-tgt> "
"[simple source-trigger target-trigger] [ignore-punctuation] [domain-trigger] [filename-src] [filename-tgt]");
return false;
}
// set factor
vector <string> factors = Tokenize(tokens[0],"-");
FactorType factorIdSource = Scan<size_t>(factors[0]);
FactorType factorIdTarget = Scan<size_t>(factors[1]);
bool simple = true, sourceTrigger = false, targetTrigger = false, ignorePunctuation = false, domainTrigger = false;
if (tokens.size() >= 4) {
simple = Scan<size_t>(tokens[1]);
sourceTrigger = Scan<size_t>(tokens[2]);
targetTrigger = Scan<size_t>(tokens[3]);
}
if (tokens.size() >= 5) {
ignorePunctuation = Scan<size_t>(tokens[4]);
}
if (tokens.size() >= 6) {
domainTrigger = Scan<size_t>(tokens[5]);
}
WordTranslationFeature *wordTranslationFeature = new WordTranslationFeature(factorIdSource, factorIdTarget, simple,
sourceTrigger, targetTrigger, ignorePunctuation, domainTrigger);
wordTranslationFeature->SetSparseProducerWeight(weight[i]);
// load word list for restricted feature set
if (tokens.size() == 7) {
string filenameSource = tokens[6];
if (domainTrigger) {
const vector<string> &texttype = m_parameter->GetParam("text-type");
if (texttype.size() != 1) {
UserMessage::Add("Need texttype to load dictionary for domain triggers.");
return false;
}
stringstream filename(filenameSource + "." + texttype[0]);
filenameSource = filename.str();
cerr << "loading word translation term list from " << filenameSource << endl;
}
else {
cerr << "loading word translation word lists from " << filenameSource << endl;
}
if (!wordTranslationFeature->Load(filenameSource, "")) {
UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource);
return false;
}
} // if (tokens.size() == 7)
else if (tokens.size() == 8) {
string filenameSource = tokens[6];
string filenameTarget = tokens[7];
cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl;
if (!wordTranslationFeature->Load(filenameSource, filenameTarget)) {
UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource + " and " + filenameTarget);
return false;
}
} //else if (tokens.size() == 8) {
// TODO not sure about this
if (weight[0] != 1) {
AddSparseProducer(wordTranslationFeature);
cerr << "wt sparse producer weight: " << weight[0] << endl;
if (m_mira)
m_metaFeatureProducer = new MetaFeatureProducer("wt");
}
if (m_parameter->GetParam("report-sparse-features").size() > 0) {
wordTranslationFeature->SetSparseFeatureReporting();
}
} // for (size_t i=0; i<parameters.size(); ++i)
return true;
}
const TranslationOptionList* StaticData::FindTransOptListInCache(const DecodeGraph &decodeGraph, const Phrase &sourcePhrase) const
{
std::pair<size_t, Phrase> key(decodeGraph.GetPosition(), sourcePhrase);

View File

@ -252,7 +252,6 @@ protected:
bool LoadReferences();
bool LoadDiscrimLMFeature();
bool LoadPhrasePairFeature();
bool LoadWordTranslationFeature();
void ReduceTransOptCache() const;
bool m_continuePartialTranslation;

View File

@ -7,12 +7,113 @@
#include "ChartHypothesis.h"
#include "ScoreComponentCollection.h"
#include "TranslationOption.h"
#include "UserMessage.h"
#include <boost/algorithm/string.hpp>
namespace Moses {
using namespace std;
WordTranslationFeature::WordTranslationFeature(const std::string &line)
:StatelessFeatureFunction("WordTranslationFeature", ScoreProducer::unlimited)
{
std::cerr << "Initializing word translation feature.. " << endl;
vector<string> tokens = Tokenize(line);
//CHECK(tokens[0] == m_description);
if (tokens.size() != 1 && !(tokens.size() >= 4 && tokens.size() <= 9)) {
UserMessage::Add("Format of word translation feature parameter is: --word-translation-feature <factor-src>-<factor-tgt> "
"[simple source-trigger target-trigger] [ignore-punctuation] [domain-trigger] [filename-src] [filename-tgt] [text-type]");
//return false;
}
// set factor
vector <string> factors = Tokenize(tokens[1],"-");
m_factorTypeSource = Scan<FactorType>(factors[0]);
m_factorTypeTarget = Scan<FactorType>(factors[1]);
m_unrestricted = true;
m_sparseProducerWeight = 1;
m_simple = true;
m_sourceContext = false;
m_targetContext = false;
m_ignorePunctuation = false;
m_domainTrigger = false;
if (tokens.size() >= 5) {
m_simple = Scan<size_t>(tokens[2]);
m_sourceContext = Scan<size_t>(tokens[3]);
m_targetContext = Scan<size_t>(tokens[4]);
}
if (tokens.size() >= 6) {
m_ignorePunctuation = Scan<size_t>(tokens[5]);
}
if (tokens.size() >= 7) {
m_domainTrigger = Scan<size_t>(tokens[6]);
}
if (m_simple == 1) std::cerr << "using simple word translations.. ";
if (m_sourceContext == 1) std::cerr << "using source context.. ";
if (m_targetContext == 1) std::cerr << "using target context.. ";
if (m_domainTrigger == 1) std::cerr << "using domain triggers.. ";
// compile a list of punctuation characters
if (m_ignorePunctuation) {
std::cerr << "ignoring punctuation for triggers.. ";
char punctuation[] = "\"'!?¿·()#_,.:;•&@/\\0123456789~=";
for (size_t i=0; i < sizeof(punctuation)-1; ++i) {
m_punctuationHash[punctuation[i]] = 1;
}
}
std::cerr << "done." << std::endl;
// load word list for restricted feature set
if (tokens.size() == 9) {
string filenameSource = tokens[7];
if (m_domainTrigger) {
const string &texttype = tokens[8];
stringstream filename(filenameSource + "." + texttype);
filenameSource = filename.str();
cerr << "loading word translation term list from " << filenameSource << endl;
}
else {
cerr << "loading word translation word lists from " << filenameSource << endl;
}
if (!Load(filenameSource, "")) {
UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource);
//return false;
}
} // if (tokens.size() == 7)
else if (tokens.size() == 10) {
// TODO need to change this
string filenameSource = tokens[7];
string filenameTarget = tokens[8];
cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl;
if (!Load(filenameSource, filenameTarget)) {
UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource + " and " + filenameTarget);
//return false;
}
} //else if (tokens.size() == 8) {
// TODO not sure about this
/*
if (weight[0] != 1) {
AddSparseProducer(wordTranslationFeature);
cerr << "wt sparse producer weight: " << weight[0] << endl;
if (m_mira)
m_metaFeatureProducer = new MetaFeatureProducer("wt");
}
if (m_parameter->GetParam("report-sparse-features").size() > 0) {
wordTranslationFeature->SetSparseFeatureReporting();
}
*/
}
bool WordTranslationFeature::Load(const std::string &filePathSource, const std::string &filePathTarget)
{
if (m_domainTrigger) {

View File

@ -36,36 +36,7 @@ private:
CharHash m_punctuationHash;
public:
WordTranslationFeature(FactorType factorTypeSource, FactorType factorTypeTarget,
bool simple, bool sourceContext, bool targetContext, bool ignorePunctuation,
bool domainTrigger):
StatelessFeatureFunction("WordTranslationFeature", ScoreProducer::unlimited),
m_factorTypeSource(factorTypeSource),
m_factorTypeTarget(factorTypeTarget),
m_unrestricted(true),
m_simple(simple),
m_sourceContext(sourceContext),
m_targetContext(targetContext),
m_domainTrigger(domainTrigger),
m_sparseProducerWeight(1),
m_ignorePunctuation(ignorePunctuation)
{
std::cerr << "Initializing word translation feature.. ";
if (m_simple == 1) std::cerr << "using simple word translations.. ";
if (m_sourceContext == 1) std::cerr << "using source context.. ";
if (m_targetContext == 1) std::cerr << "using target context.. ";
if (m_domainTrigger == 1) std::cerr << "using domain triggers.. ";
// compile a list of punctuation characters
if (m_ignorePunctuation) {
std::cerr << "ignoring punctuation for triggers.. ";
char punctuation[] = "\"'!?¿·()#_,.:;•&@/\\0123456789~=";
for (size_t i=0; i < sizeof(punctuation)-1; ++i)
m_punctuationHash[punctuation[i]] = 1;
}
std::cerr << "done." << std::endl;
}
WordTranslationFeature(const std::string &line);
bool Load(const std::string &filePathSource, const std::string &filePathTarget);