add [feature] arg. Use for GlobalLexicalModelUnlimited. Not tested

Hieu Hoang 2013-01-01 17:27:26 +00:00
parent 0784921314
commit de45d7076a
4 changed files with 84 additions and 81 deletions

View File

@@ -8,6 +8,61 @@ using namespace std;
namespace Moses
{
GlobalLexicalModelUnlimited::GlobalLexicalModelUnlimited(const std::string &line)
  :StatelessFeatureFunction("glm",ScoreProducer::unlimited)
{
  const vector<string> modelSpec = Tokenize(line);
  for (size_t i = 0; i < modelSpec.size(); i++ ) {
    bool ignorePunctuation = true, biasFeature = false, restricted = false;
    size_t context = 0;
    string filenameSource, filenameTarget;
    vector< string > factors;
    vector< string > spec = Tokenize(modelSpec[i]," ");
    // read optional punctuation and bias specifications
    if (spec.size() > 0) {
      if (spec.size() != 2 && spec.size() != 3 && spec.size() != 4 && spec.size() != 6) {
        UserMessage::Add("Format of glm feature is <factor-src>-<factor-tgt> [ignore-punct] [use-bias] "
                         "[context-type] [filename-src filename-tgt]");
        //return false;
      }
      factors = Tokenize(spec[0],"-");
      if (spec.size() >= 2)
        ignorePunctuation = Scan<size_t>(spec[1]);
      if (spec.size() >= 3)
        biasFeature = Scan<size_t>(spec[2]);
      if (spec.size() >= 4)
        context = Scan<size_t>(spec[3]);
      if (spec.size() == 6) {
        filenameSource = spec[4];
        filenameTarget = spec[5];
        restricted = true;
      }
    }
    else
      factors = Tokenize(modelSpec[i],"-");
    if ( factors.size() != 2 ) {
      UserMessage::Add("Wrong factor definition for global lexical model unlimited: " + modelSpec[i]);
      //return false;
    }
    const vector<FactorType> inputFactors = Tokenize<FactorType>(factors[0],",");
    const vector<FactorType> outputFactors = Tokenize<FactorType>(factors[1],",");
    throw runtime_error("GlobalLexicalModelUnlimited should be reimplemented as a stateful feature");
    GlobalLexicalModelUnlimited* glmu = NULL; // new GlobalLexicalModelUnlimited(inputFactors, outputFactors, biasFeature, ignorePunctuation, context);
    if (restricted) {
      cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl;
      if (!glmu->Load(filenameSource, filenameTarget)) {
        UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource + " and " + filenameTarget);
        //return false;
      }
    }
  }
}
bool GlobalLexicalModelUnlimited::Load(const std::string &filePathSource,
const std::string &filePathTarget)
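
Note (not part of the commit): the spec string handled by the new constructor follows the format quoted in the error message above. The snippet below is a minimal standalone sketch of how a hypothetical six-token spec decomposes into the parsed fields; the sample factor indices and filenames are invented, and std::istringstream/std::stoul stand in for Moses' Tokenize and Scan helpers.

// Standalone illustration only -- not Moses code.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main()
{
  // <factor-src>-<factor-tgt> [ignore-punct] [use-bias] [context-type] [filename-src filename-tgt]
  std::string modelSpec = "0-0 1 0 2 vocab.src.txt vocab.tgt.txt";  // invented example

  std::vector<std::string> spec;
  std::istringstream in(modelSpec);
  for (std::string tok; in >> tok; )
    spec.push_back(tok);

  // the constructor accepts 2, 3, 4 or 6 tokens; defaults mirror the code above
  bool ignorePunctuation = true, biasFeature = false, restricted = false;
  size_t context = 0;
  std::string filenameSource, filenameTarget;

  if (spec.size() >= 2) ignorePunctuation = (spec[1] == "1");
  if (spec.size() >= 3) biasFeature = (spec[2] == "1");
  if (spec.size() >= 4) context = std::stoul(spec[3]);
  if (spec.size() == 6) {
    filenameSource = spec[4];
    filenameTarget = spec[5];
    restricted = true;  // word lists restrict the feature to the listed vocabulary
  }

  std::cout << "factors=" << spec[0]
            << " ignorePunct=" << ignorePunctuation
            << " bias=" << biasFeature
            << " context=" << context
            << " restricted=" << restricted
            << " files=" << filenameSource << "," << filenameTarget << std::endl;
}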

View File

@@ -68,6 +68,7 @@ private:
  std::set<std::string> m_vocabTarget;
public:
  GlobalLexicalModelUnlimited(const std::string &line);
  GlobalLexicalModelUnlimited(const std::vector< FactorType >& inFactors, const std::vector< FactorType >& outFactors,
                              bool biasFeature, bool ignorePunctuation, size_t context):
    StatelessFeatureFunction("glm",ScoreProducer::unlimited),

View File

@@ -83,6 +83,21 @@ static size_t CalcMax(size_t x, const vector<size_t>& y, const vector<size_t>& z
  return max;
}
int GetFeatureIndex(std::map<string, int> &map, const string &featureName)
{
  std::map<string, int>::iterator iter;
  iter = map.find(featureName);
  if (iter == map.end()) {
    map[featureName] = 0;
    return 0;
  }
  else {
    int &index = iter->second;
    ++index;
    return index;
  }
}
StaticData StaticData::s_instance;
StaticData::StaticData()
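
As a quick standalone check of the helper above (illustration only, not part of the commit): the first occurrence of a feature name yields index 0, and each repeated occurrence yields 1, 2, and so on, counted separately per name.

// Illustration only: per-name instance counting done by GetFeatureIndex.
#include <iostream>
#include <map>
#include <string>

int GetFeatureIndex(std::map<std::string, int> &map, const std::string &featureName)
{
  std::map<std::string, int>::iterator iter = map.find(featureName);
  if (iter == map.end()) {
    map[featureName] = 0;   // first time this feature name is seen
    return 0;
  }
  int &index = iter->second;
  ++index;                  // every further occurrence gets the next index
  return index;
}

int main()
{
  std::map<std::string, int> featureIndexMap;
  std::cout << GetFeatureIndex(featureIndexMap, "GlobalLexicalModel") << std::endl; // 0
  std::cout << GetFeatureIndex(featureIndexMap, "GlobalLexicalModel") << std::endl; // 1
  std::cout << GetFeatureIndex(featureIndexMap, "glm") << std::endl;                // 0
}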
@@ -530,14 +545,24 @@ SetWeight(m_unknownWordPenaltyProducer, weightUnknownWord);
  }
  // all features
  map<string, int> featureIndexMap;
  const vector<string> &features = m_parameter->GetParam("feature");
  for (size_t i = 0; i < features.size(); ++i) {
    const string &line = features[i];
    vector<string> toks = Tokenize(line);
    if (toks[0] == "GlobalLexicalModel") {
    const string &feature = toks[0];
    int featureIndex = GetFeatureIndex(featureIndexMap, feature);
    if (feature == "GlobalLexicalModel") {
      GlobalLexicalModel *model = new GlobalLexicalModel(line);
      const vector<float> &weights = m_parameter->GetWeights(toks[0], 0);
      const vector<float> &weights = m_parameter->GetWeights(feature, featureIndex);
      SetWeights(model, weights);
    }
    else if (feature == "glm") {
      GlobalLexicalModelUnlimited *model = NULL; //new GlobalLexicalModelUnlimited(line);
      const vector<float> &weights = m_parameter->GetWeights(feature, featureIndex);
      SetWeights(model, weights);
    }
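
The point of passing featureIndex to GetWeights is that repeated [feature] lines with the same name can each get their own weight block instead of all reading block 0. Below is a standalone sketch of that per-instance lookup; it is not Moses code, the weight values are invented, and the nested-vector store is only an assumption about what Parameter::GetWeights does for repeated instances.

// Illustration only: per-instance weight lookup keyed by (feature, featureIndex).
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main()
{
  // pretend weight blocks for two GlobalLexicalModel instances
  std::map<std::string, std::vector<std::vector<float> > > weightBlocks;
  weightBlocks["GlobalLexicalModel"].push_back(std::vector<float>(1, 0.3f)); // instance 0
  weightBlocks["GlobalLexicalModel"].push_back(std::vector<float>(1, 0.1f)); // instance 1

  std::map<std::string, int> featureIndexMap;   // same role as in the loop above
  const char *features[] = { "GlobalLexicalModel", "GlobalLexicalModel" };

  for (int i = 0; i < 2; ++i) {
    std::string feature = features[i];
    // inline equivalent of GetFeatureIndex: 0 for the first occurrence, then 1, 2, ...
    int featureIndex = featureIndexMap.count(feature)
                       ? ++featureIndexMap[feature]
                       : (featureIndexMap[feature] = 0);
    const std::vector<float> &weights = weightBlocks[feature][featureIndex];
    std::cout << feature << " #" << featureIndex << " -> weight " << weights[0] << std::endl;
  }
}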
@@ -553,7 +578,6 @@ SetWeight(m_unknownWordPenaltyProducer, weightUnknownWord);
  if (!LoadLanguageModels()) return false;
  if (!LoadGenerationTables()) return false;
  if (!LoadPhraseTables()) return false;
  if (!LoadGlobalLexicalModelUnlimited()) return false;
  if (!LoadDecodeGraphs()) return false;
  if (!LoadReferences()) return false;
  if (!LoadDiscrimLMFeature()) return false;
@@ -601,16 +625,7 @@ SetWeight(m_unknownWordPenaltyProducer, weightUnknownWord);
      UserMessage::Add("Unable to load weights from " + extraWeightConfig[0]);
      return false;
    }
    // GLM: apply additional weight to sparse features if applicable
    for (size_t i = 0; i < m_globalLexicalModelsUnlimited.size(); ++i) {
      float weight = m_globalLexicalModelsUnlimited[i]->GetSparseProducerWeight();
      if (weight != 1) {
        AddSparseProducer(m_globalLexicalModelsUnlimited[i]);
        cerr << "glm sparse producer weight: " << weight << endl;
      }
    }
    m_allWeights.PlusEquals(extraWeights);
  }
@@ -778,72 +793,6 @@ bool StaticData::LoadLexicalReorderingModel()
  return true;
}
bool StaticData::LoadGlobalLexicalModelUnlimited()
{
  const vector<float> &weight = Scan<float>(m_parameter->GetParam("weight-glm"));
  const vector<string> &modelSpec = m_parameter->GetParam("glm-feature");
  if (weight.size() != 0 && weight.size() != modelSpec.size()) {
    std::cerr << "number of sparse producer weights and model specs for the global lexical model unlimited "
                 "does not match (" << weight.size() << " != " << modelSpec.size() << ")" << std::endl;
    return false;
  }
  for (size_t i = 0; i < modelSpec.size(); i++ ) {
    bool ignorePunctuation = true, biasFeature = false, restricted = false;
    size_t context = 0;
    string filenameSource, filenameTarget;
    vector< string > factors;
    vector< string > spec = Tokenize(modelSpec[i]," ");
    // read optional punctuation and bias specifications
    if (spec.size() > 0) {
      if (spec.size() != 2 && spec.size() != 3 && spec.size() != 4 && spec.size() != 6) {
        UserMessage::Add("Format of glm feature is <factor-src>-<factor-tgt> [ignore-punct] [use-bias] "
                         "[context-type] [filename-src filename-tgt]");
        return false;
      }
      factors = Tokenize(spec[0],"-");
      if (spec.size() >= 2)
        ignorePunctuation = Scan<size_t>(spec[1]);
      if (spec.size() >= 3)
        biasFeature = Scan<size_t>(spec[2]);
      if (spec.size() >= 4)
        context = Scan<size_t>(spec[3]);
      if (spec.size() == 6) {
        filenameSource = spec[4];
        filenameTarget = spec[5];
        restricted = true;
      }
    }
    else
      factors = Tokenize(modelSpec[i],"-");
    if ( factors.size() != 2 ) {
      UserMessage::Add("Wrong factor definition for global lexical model unlimited: " + modelSpec[i]);
      return false;
    }
    const vector<FactorType> inputFactors = Tokenize<FactorType>(factors[0],",");
    const vector<FactorType> outputFactors = Tokenize<FactorType>(factors[1],",");
    throw runtime_error("GlobalLexicalModelUnlimited should be reimplemented as a stateful feature");
    GlobalLexicalModelUnlimited* glmu = NULL; // new GlobalLexicalModelUnlimited(inputFactors, outputFactors, biasFeature, ignorePunctuation, context);
    m_globalLexicalModelsUnlimited.push_back(glmu);
    if (restricted) {
      cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl;
      if (!glmu->Load(filenameSource, filenameTarget)) {
        UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource + " and " + filenameTarget);
        return false;
      }
    }
    if (weight.size() > i)
      m_globalLexicalModelsUnlimited[i]->SetSparseProducerWeight(weight[i]);
  }
  return true;
}
bool StaticData::LoadLanguageModels()
{
  if (m_parameter->GetParam("lmodel-file").size() > 0) {

View File

@@ -89,7 +89,6 @@ protected:
  LMList m_languageModel;
  ScoreComponentCollection m_allWeights;
  std::vector<LexicalReordering*> m_reorderModels;
  std::vector<GlobalLexicalModelUnlimited*> m_globalLexicalModelsUnlimited;
#ifdef HAVE_SYNLM
  SyntacticLanguageModel* m_syntacticLanguageModel;
#endif
@@ -249,7 +248,6 @@ protected:
  //! load decoding steps
  bool LoadDecodeGraphs();
  bool LoadLexicalReorderingModel();
  bool LoadGlobalLexicalModelUnlimited();
  //References used for scoring feature (eg BleuScoreFeature) for online training
  bool LoadReferences();
  bool LoadDiscrimLMFeature();