mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-26 21:42:19 +03:00
move code from LoadWordTranslationFeature() to constructor in class WordTranslationFeature
This commit is contained in:
parent
b71d6c9b24
commit
2f6b445bf6
@ -151,7 +151,6 @@ Parameter::Parameter()
|
||||
AddParam("unknown-lhs", "file containing target lhs of unknown words. 1 per line: LHS prob");
|
||||
AddParam("show-weights", "print feature weights and exit");
|
||||
AddParam("start-translation-id", "Id of 1st input. Default = 0");
|
||||
AddParam("text-type", "should be one of dev/devtest/test, used for domain adaptation features");
|
||||
AddParam("output-unknowns", "Output the unknown (OOV) words to the given file, one line per sentence");
|
||||
|
||||
// Compact phrase table and reordering table.
|
||||
@ -184,6 +183,7 @@ Parameter::Parameter()
|
||||
AddParam("weight-w", "w", "DEPRECATED. DO NOT USE. weight for word penalty");
|
||||
AddParam("weight-u", "u", "DEPRECATED. DO NOT USE. weight for unknown word penalty");
|
||||
AddParam("weight-e", "e", "DEPRECATED. DO NOT USE. weight for word deletion");
|
||||
//AddParam("text-type", "DEPRECATED. DO NOT USE. should be one of dev/devtest/test, used for domain adaptation features");
|
||||
|
||||
AddParam("weight-file", "wf", "feature weights file. Do *not* put weights for 'core' features in here - they go in moses.ini");
|
||||
|
||||
|
@ -585,6 +585,11 @@ SetWeight(m_unknownWordPenaltyProducer, weightUnknownWord);
|
||||
const vector<float> &weights = m_parameter->GetWeights(feature, featureIndex);
|
||||
//SetWeights(model, weights);
|
||||
}
|
||||
else if (feature == "WordTranslationFeature") {
|
||||
WordTranslationFeature *model = new WordTranslationFeature(line);
|
||||
const vector<float> &weights = m_parameter->GetWeights(feature, featureIndex);
|
||||
//SetWeights(model, weights);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -602,7 +607,6 @@ SetWeight(m_unknownWordPenaltyProducer, weightUnknownWord);
|
||||
if (!LoadReferences()) return false;
|
||||
if (!LoadDiscrimLMFeature()) return false;
|
||||
if (!LoadPhrasePairFeature()) return false;
|
||||
if (!LoadWordTranslationFeature()) return false;
|
||||
|
||||
// report individual sparse features in n-best list
|
||||
if (m_parameter->GetParam("report-sparse-features").size() > 0) {
|
||||
@ -1424,95 +1428,6 @@ bool StaticData::LoadPhrasePairFeature()
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StaticData::LoadWordTranslationFeature()
|
||||
{
|
||||
const vector<string> ¶meters = m_parameter->GetParam("word-translation-feature");
|
||||
if (parameters.empty())
|
||||
return true;
|
||||
|
||||
const vector<float> &weight = m_parameter->GetWeights("WordPenalty", 0);
|
||||
CHECK(weight.size() == 1);
|
||||
|
||||
m_needAlignmentInfo = true;
|
||||
|
||||
for (size_t i=0; i<parameters.size(); ++i) {
|
||||
vector<string> tokens = Tokenize(parameters[i]);
|
||||
if (tokens.size() != 1 && !(tokens.size() >= 4 && tokens.size() <= 8)) {
|
||||
UserMessage::Add("Format of word translation feature parameter is: --word-translation-feature <factor-src>-<factor-tgt> "
|
||||
"[simple source-trigger target-trigger] [ignore-punctuation] [domain-trigger] [filename-src] [filename-tgt]");
|
||||
return false;
|
||||
}
|
||||
|
||||
// set factor
|
||||
vector <string> factors = Tokenize(tokens[0],"-");
|
||||
FactorType factorIdSource = Scan<size_t>(factors[0]);
|
||||
FactorType factorIdTarget = Scan<size_t>(factors[1]);
|
||||
|
||||
bool simple = true, sourceTrigger = false, targetTrigger = false, ignorePunctuation = false, domainTrigger = false;
|
||||
if (tokens.size() >= 4) {
|
||||
simple = Scan<size_t>(tokens[1]);
|
||||
sourceTrigger = Scan<size_t>(tokens[2]);
|
||||
targetTrigger = Scan<size_t>(tokens[3]);
|
||||
}
|
||||
if (tokens.size() >= 5) {
|
||||
ignorePunctuation = Scan<size_t>(tokens[4]);
|
||||
}
|
||||
|
||||
if (tokens.size() >= 6) {
|
||||
domainTrigger = Scan<size_t>(tokens[5]);
|
||||
}
|
||||
|
||||
WordTranslationFeature *wordTranslationFeature = new WordTranslationFeature(factorIdSource, factorIdTarget, simple,
|
||||
sourceTrigger, targetTrigger, ignorePunctuation, domainTrigger);
|
||||
wordTranslationFeature->SetSparseProducerWeight(weight[i]);
|
||||
|
||||
// load word list for restricted feature set
|
||||
if (tokens.size() == 7) {
|
||||
string filenameSource = tokens[6];
|
||||
if (domainTrigger) {
|
||||
const vector<string> &texttype = m_parameter->GetParam("text-type");
|
||||
if (texttype.size() != 1) {
|
||||
UserMessage::Add("Need texttype to load dictionary for domain triggers.");
|
||||
return false;
|
||||
}
|
||||
stringstream filename(filenameSource + "." + texttype[0]);
|
||||
filenameSource = filename.str();
|
||||
cerr << "loading word translation term list from " << filenameSource << endl;
|
||||
}
|
||||
else {
|
||||
cerr << "loading word translation word lists from " << filenameSource << endl;
|
||||
}
|
||||
if (!wordTranslationFeature->Load(filenameSource, "")) {
|
||||
UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource);
|
||||
return false;
|
||||
}
|
||||
} // if (tokens.size() == 7)
|
||||
else if (tokens.size() == 8) {
|
||||
string filenameSource = tokens[6];
|
||||
string filenameTarget = tokens[7];
|
||||
cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl;
|
||||
if (!wordTranslationFeature->Load(filenameSource, filenameTarget)) {
|
||||
UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource + " and " + filenameTarget);
|
||||
return false;
|
||||
}
|
||||
} //else if (tokens.size() == 8) {
|
||||
|
||||
// TODO not sure about this
|
||||
if (weight[0] != 1) {
|
||||
AddSparseProducer(wordTranslationFeature);
|
||||
cerr << "wt sparse producer weight: " << weight[0] << endl;
|
||||
if (m_mira)
|
||||
m_metaFeatureProducer = new MetaFeatureProducer("wt");
|
||||
}
|
||||
|
||||
if (m_parameter->GetParam("report-sparse-features").size() > 0) {
|
||||
wordTranslationFeature->SetSparseFeatureReporting();
|
||||
}
|
||||
} // for (size_t i=0; i<parameters.size(); ++i)
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
const TranslationOptionList* StaticData::FindTransOptListInCache(const DecodeGraph &decodeGraph, const Phrase &sourcePhrase) const
|
||||
{
|
||||
std::pair<size_t, Phrase> key(decodeGraph.GetPosition(), sourcePhrase);
|
||||
|
@ -252,7 +252,6 @@ protected:
|
||||
bool LoadReferences();
|
||||
bool LoadDiscrimLMFeature();
|
||||
bool LoadPhrasePairFeature();
|
||||
bool LoadWordTranslationFeature();
|
||||
|
||||
void ReduceTransOptCache() const;
|
||||
bool m_continuePartialTranslation;
|
||||
|
@ -7,12 +7,113 @@
|
||||
#include "ChartHypothesis.h"
|
||||
#include "ScoreComponentCollection.h"
|
||||
#include "TranslationOption.h"
|
||||
#include "UserMessage.h"
|
||||
#include <boost/algorithm/string.hpp>
|
||||
|
||||
namespace Moses {
|
||||
|
||||
using namespace std;
|
||||
|
||||
WordTranslationFeature::WordTranslationFeature(const std::string &line)
|
||||
:StatelessFeatureFunction("WordTranslationFeature", ScoreProducer::unlimited)
|
||||
{
|
||||
std::cerr << "Initializing word translation feature.. " << endl;
|
||||
|
||||
vector<string> tokens = Tokenize(line);
|
||||
//CHECK(tokens[0] == m_description);
|
||||
|
||||
if (tokens.size() != 1 && !(tokens.size() >= 4 && tokens.size() <= 9)) {
|
||||
UserMessage::Add("Format of word translation feature parameter is: --word-translation-feature <factor-src>-<factor-tgt> "
|
||||
"[simple source-trigger target-trigger] [ignore-punctuation] [domain-trigger] [filename-src] [filename-tgt] [text-type]");
|
||||
//return false;
|
||||
}
|
||||
|
||||
// set factor
|
||||
vector <string> factors = Tokenize(tokens[1],"-");
|
||||
m_factorTypeSource = Scan<FactorType>(factors[0]);
|
||||
m_factorTypeTarget = Scan<FactorType>(factors[1]);
|
||||
|
||||
m_unrestricted = true;
|
||||
m_sparseProducerWeight = 1;
|
||||
m_simple = true;
|
||||
m_sourceContext = false;
|
||||
m_targetContext = false;
|
||||
m_ignorePunctuation = false;
|
||||
m_domainTrigger = false;
|
||||
if (tokens.size() >= 5) {
|
||||
m_simple = Scan<size_t>(tokens[2]);
|
||||
m_sourceContext = Scan<size_t>(tokens[3]);
|
||||
m_targetContext = Scan<size_t>(tokens[4]);
|
||||
}
|
||||
if (tokens.size() >= 6) {
|
||||
m_ignorePunctuation = Scan<size_t>(tokens[5]);
|
||||
}
|
||||
|
||||
if (tokens.size() >= 7) {
|
||||
m_domainTrigger = Scan<size_t>(tokens[6]);
|
||||
}
|
||||
|
||||
if (m_simple == 1) std::cerr << "using simple word translations.. ";
|
||||
if (m_sourceContext == 1) std::cerr << "using source context.. ";
|
||||
if (m_targetContext == 1) std::cerr << "using target context.. ";
|
||||
if (m_domainTrigger == 1) std::cerr << "using domain triggers.. ";
|
||||
|
||||
// compile a list of punctuation characters
|
||||
if (m_ignorePunctuation) {
|
||||
std::cerr << "ignoring punctuation for triggers.. ";
|
||||
char punctuation[] = "\"'!?¿·()#_,.:;•&@‑/\\0123456789~=";
|
||||
for (size_t i=0; i < sizeof(punctuation)-1; ++i) {
|
||||
m_punctuationHash[punctuation[i]] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::cerr << "done." << std::endl;
|
||||
|
||||
// load word list for restricted feature set
|
||||
if (tokens.size() == 9) {
|
||||
string filenameSource = tokens[7];
|
||||
if (m_domainTrigger) {
|
||||
const string &texttype = tokens[8];
|
||||
|
||||
stringstream filename(filenameSource + "." + texttype);
|
||||
filenameSource = filename.str();
|
||||
cerr << "loading word translation term list from " << filenameSource << endl;
|
||||
}
|
||||
else {
|
||||
cerr << "loading word translation word lists from " << filenameSource << endl;
|
||||
}
|
||||
if (!Load(filenameSource, "")) {
|
||||
UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource);
|
||||
//return false;
|
||||
}
|
||||
} // if (tokens.size() == 7)
|
||||
else if (tokens.size() == 10) {
|
||||
// TODO need to change this
|
||||
string filenameSource = tokens[7];
|
||||
string filenameTarget = tokens[8];
|
||||
cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl;
|
||||
if (!Load(filenameSource, filenameTarget)) {
|
||||
UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource + " and " + filenameTarget);
|
||||
//return false;
|
||||
}
|
||||
} //else if (tokens.size() == 8) {
|
||||
|
||||
// TODO not sure about this
|
||||
/*
|
||||
if (weight[0] != 1) {
|
||||
AddSparseProducer(wordTranslationFeature);
|
||||
cerr << "wt sparse producer weight: " << weight[0] << endl;
|
||||
if (m_mira)
|
||||
m_metaFeatureProducer = new MetaFeatureProducer("wt");
|
||||
}
|
||||
|
||||
if (m_parameter->GetParam("report-sparse-features").size() > 0) {
|
||||
wordTranslationFeature->SetSparseFeatureReporting();
|
||||
}
|
||||
*/
|
||||
|
||||
}
|
||||
|
||||
bool WordTranslationFeature::Load(const std::string &filePathSource, const std::string &filePathTarget)
|
||||
{
|
||||
if (m_domainTrigger) {
|
||||
|
@ -36,36 +36,7 @@ private:
|
||||
CharHash m_punctuationHash;
|
||||
|
||||
public:
|
||||
WordTranslationFeature(FactorType factorTypeSource, FactorType factorTypeTarget,
|
||||
bool simple, bool sourceContext, bool targetContext, bool ignorePunctuation,
|
||||
bool domainTrigger):
|
||||
StatelessFeatureFunction("WordTranslationFeature", ScoreProducer::unlimited),
|
||||
m_factorTypeSource(factorTypeSource),
|
||||
m_factorTypeTarget(factorTypeTarget),
|
||||
m_unrestricted(true),
|
||||
m_simple(simple),
|
||||
m_sourceContext(sourceContext),
|
||||
m_targetContext(targetContext),
|
||||
m_domainTrigger(domainTrigger),
|
||||
m_sparseProducerWeight(1),
|
||||
m_ignorePunctuation(ignorePunctuation)
|
||||
{
|
||||
std::cerr << "Initializing word translation feature.. ";
|
||||
if (m_simple == 1) std::cerr << "using simple word translations.. ";
|
||||
if (m_sourceContext == 1) std::cerr << "using source context.. ";
|
||||
if (m_targetContext == 1) std::cerr << "using target context.. ";
|
||||
if (m_domainTrigger == 1) std::cerr << "using domain triggers.. ";
|
||||
|
||||
// compile a list of punctuation characters
|
||||
if (m_ignorePunctuation) {
|
||||
std::cerr << "ignoring punctuation for triggers.. ";
|
||||
char punctuation[] = "\"'!?¿·()#_,.:;•&@‑/\\0123456789~=";
|
||||
for (size_t i=0; i < sizeof(punctuation)-1; ++i)
|
||||
m_punctuationHash[punctuation[i]] = 1;
|
||||
}
|
||||
|
||||
std::cerr << "done." << std::endl;
|
||||
}
|
||||
WordTranslationFeature(const std::string &line);
|
||||
|
||||
bool Load(const std::string &filePathSource, const std::string &filePathTarget);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user