input weight formatted according to discussions with Ken & Barry

This commit is contained in:
hieu 2012-12-12 18:30:11 +00:00
parent 5ab3c3f081
commit e84f7d6292
2 changed files with 112 additions and 44 deletions

View File

@ -325,15 +325,45 @@ std::vector<float> &Parameter::GetWeights(const std::string &name, size_t ind)
return m_weights[name + SPrint(ind)];
}
// Records one scalar weight for feature `name`, instance `ind`, by appending a
// single text line to m_setting["weight"].
// NOTE(review): this span comes from a rendered diff; the two signature lines and
// the two `string line` lines below look like the before/after versions of the
// same line (AddWeight -> SetWeight, " " -> "= ") — confirm against the real file.
void Parameter::AddWeight(const std::string &name, size_t ind, float weight)
void Parameter::SetWeight(const std::string &name, size_t ind, float weight)
{
// operator[] creates the "weight" entry if it does not exist yet
PARAM_VEC &newWeights = m_setting["weight"];
string line = name + SPrint(ind) + " " + SPrint(weight);
// key is name immediately followed by ind and "=", then a space and the value
string line = name + SPrint(ind) + "= " + SPrint(weight);
newWeights.push_back(line);
}
void Parameter::ConvertWeightArgsDefault(const string &oldWeightName, const string &newWeightName)
void Parameter::SetWeight(const std::string &name, size_t ind, const vector<float> &weights)
{
PARAM_VEC &newWeights = m_setting["weight"];
string line = name + SPrint(ind) + "=";
for (size_t i = 0; i < weights.size(); ++i) {
line += " " + SPrint(weights[i]);
}
newWeights.push_back(line);
}
void Parameter::AddWeight(const std::string &name, size_t ind, const std::vector<float> &weights)
{
PARAM_VEC &newWeights = m_setting["weight"];
string sought = name + SPrint(ind) + "=";
for (size_t i = 0; i < newWeights.size(); ++i) {
string &line = newWeights[i];
if (line.find(sought) == 0) {
// found existing weight, most likely to be input weights. Append to this line
for (size_t i = 0; i < weights.size(); ++i) {
line += " " + SPrint(weights[i]);
}
return;
}
}
// nothing found. Just set
SetWeight(name, ind, weights);
}
void Parameter::ConvertWeightArgsSingleWeight(const string &oldWeightName, const string &newWeightName)
{
size_t ind = 0;
PARAM_MAP::iterator iterMap;
@ -344,7 +374,7 @@ void Parameter::ConvertWeightArgsDefault(const string &oldWeightName, const stri
const PARAM_VEC &weights = iterMap->second;
for (size_t i = 0; i < weights.size(); ++i)
{
AddWeight(newWeightName, ind, Scan<float>(weights[i]));
SetWeight(newWeightName, ind, Scan<float>(weights[i]));
}
m_setting.erase(iterMap);
@ -364,16 +394,60 @@ void Parameter::ConvertWeightArgsT(const string &oldWeightName, const string &ne
size_t numFFInd = (toks.size() == 4) ? 2 : 3;
size_t numFF = Scan<size_t>(toks[numFFInd]);
vector<float> weights(numFF);
for (size_t currFF = 0; currFF < numFF; ++currFF) {
CHECK(currOldInd < oldWeights.size());
float weight = Scan<float>(oldWeights[currOldInd]);
AddWeight(newWeightName, ttableInd, weight);
weights[currFF] = weight;
++currOldInd;
}
AddWeight(newWeightName, ttableInd, weights);
}
}
/** Convert the old "weight-d" values into new-style weight lines.
 *
 * Does nothing when m_setting["weight-d"] is empty. Otherwise:
 *  - the first value is passed to SetWeight under the name "Distortion", index 0;
 *  - for each line of m_setting["distortion-file"], in order, the third
 *    whitespace token is read as a count, that many of the following old
 *    values are collected and passed to SetWeight under the name
 *    "LexicalReordering" with the line's position as index;
 *  - the "weight-d" entry is then erased.
 *
 * CHECK fires if a distortion-file line has fewer than three tokens or if the
 * old values run out before every table has received its share.
 */
void Parameter::ConvertWeightArgsDistortion()
{
  const string oldWeightName = "weight-d";

  // distortion / lex distortion
  // NOTE(review): operator[] creates an empty entry when the key is absent, and the
  // empty "weight-d" entry is not erased below — confirm nothing relies on key presence.
  PARAM_VEC &oldWeights = m_setting[oldWeightName];

  if (oldWeights.size() > 0)
  {
    // distance distortion
    SetWeight("Distortion", 0, Scan<float>(oldWeights[0]));

    // the remaining values are handed out to the lex reordering tables, in table order
    size_t currOldInd = 1;
    PARAM_VEC &lextable = m_setting["distortion-file"];

    for (size_t indTable = 0; indTable < lextable.size(); ++indTable) {
      string &line = lextable[indTable];
      vector<string> toks = Tokenize(line);

      // guard the toks[2] access below against a malformed table line
      CHECK(toks.size() > 2);
      size_t numFF = Scan<size_t>(toks[2]);

      vector<float> weights(numFF);
      for (size_t currFF = 0; currFF < numFF; ++currFF)
      {
        CHECK(currOldInd < oldWeights.size());
        float weight = Scan<float>(oldWeights[currOldInd]);
        weights[currFF] = weight;
        ++currOldInd;
      }
      SetWeight("LexicalReordering", indTable, weights);
    }

    // NOTE(review): any old values beyond what the tables consume are dropped here — confirm intended.
    m_setting.erase(oldWeightName);
  }
}
void Parameter::ConvertWeightArgs()
{
// check that old & new format aren't mixed
@ -400,43 +474,27 @@ void Parameter::ConvertWeightArgs()
numInputScores.push_back("1");
}
ConvertWeightArgsDefault("weight-i", "PhraseModel");
ConvertWeightArgsSingleWeight("weight-i", "PhraseModel");
ConvertWeightArgsT("weight-t", "PhraseModel");
ConvertWeightArgsDefault("weight-w", "WordPenalty");
ConvertWeightArgsDefault("weight-l", "LM");
ConvertWeightArgsDefault("weight-slm", "SyntacticLM");
ConvertWeightArgsDefault("weight-u", "UnknownWordPenalty");
ConvertWeightArgsDefault("weight-generation", "Generation");
ConvertWeightArgsSingleWeight("weight-w", "WordPenalty");
ConvertWeightArgsSingleWeight("weight-l", "LM");
ConvertWeightArgsSingleWeight("weight-slm", "SyntacticLM");
ConvertWeightArgsSingleWeight("weight-u", "UnknownWordPenalty");
ConvertWeightArgsSingleWeight("weight-generation", "Generation");
// don't know or can't be bothered converting these weights
ConvertWeightArgsDefault("weight-lr", "LexicalReordering");
ConvertWeightArgsDefault("weight-bl", "BleuScoreFeature");
ConvertWeightArgsDefault("weight-glm", "GlobalLexicalModel");
ConvertWeightArgsDefault("weight-wt", "WordTranslationFeature");
ConvertWeightArgsDefault("weight-pp", "PhrasePairFeature");
ConvertWeightArgsDefault("weight-pb", "PhraseBoundaryFeature");
ConvertWeightArgsDefault("weight-dlm", "DiscriminativeLM");
ConvertWeightArgsSingleWeight("weight-lr", "LexicalReordering");
ConvertWeightArgsSingleWeight("weight-bl", "BleuScoreFeature");
ConvertWeightArgsSingleWeight("weight-glm", "GlobalLexicalModel");
ConvertWeightArgsSingleWeight("weight-wt", "WordTranslationFeature");
ConvertWeightArgsSingleWeight("weight-pp", "PhrasePairFeature");
ConvertWeightArgsSingleWeight("weight-pb", "PhraseBoundaryFeature");
ConvertWeightArgsSingleWeight("weight-dlm", "DiscriminativeLM");
ConvertWeightArgsDefault("weight-e", "WordDeletion"); // TODO Can't find real name
ConvertWeightArgsDefault("weight-lex", "GlobalLexicalReordering"); // TODO Can't find real name
ConvertWeightArgsSingleWeight("weight-e", "WordDeletion"); // TODO Can't find real name
ConvertWeightArgsSingleWeight("weight-lex", "GlobalLexicalReordering"); // TODO Can't find real name
// distortion / lex distortion
PARAM_VEC &newWeights = m_setting["weight"];
PARAM_VEC &weights = m_setting["weight-d"];
if (weights.size() > 0)
{
// distance distortion
AddWeight("Distortion", 0, Scan<float>(weights[0]));
// everything but the last is lex reordering model
for (size_t i = 1; i < weights.size(); ++i)
{
AddWeight("LexicalReordering", 0, Scan<float>(weights[i]));
}
m_setting.erase("weight-d");
}
ConvertWeightArgsDistortion();
}
void Parameter::CreateWeightsMap()
@ -446,11 +504,18 @@ void Parameter::CreateWeightsMap()
{
const string &line = vec[i];
vector<string> toks = Tokenize(line);
CHECK(toks.size() == 2);
cerr << line << endl;
CHECK(toks.size() >= 2);
string &name = toks[0];
float weight = Scan<float>(toks[1]);
m_weights[name].push_back(weight);
string name = toks[0];
name = name.substr(0, name.size() - 1);
vector<float> weights(toks.size() - 1);
for (size_t i = 1; i < toks.size(); ++i) {
float weight = Scan<float>(toks[i]);
weights[i - 1] = weight;
}
m_weights[name] = weights;
}
}

View File

@ -62,10 +62,13 @@ protected:
void PrintCredit();
void AddWeight(const std::string &name, size_t ind, float weight);
void SetWeight(const std::string &name, size_t ind, float weight);
void SetWeight(const std::string &name, size_t ind, const std::vector<float> &weights);
void AddWeight(const std::string &name, size_t ind, const std::vector<float> &weights);
void ConvertWeightArgs();
void ConvertWeightArgsDefault(const std::string &oldWeightName, const std::string &newWeightName);
void ConvertWeightArgsSingleWeight(const std::string &oldWeightName, const std::string &newWeightName);
void ConvertWeightArgsT(const std::string &oldWeightName, const std::string &newWeightName);
void ConvertWeightArgsDistortion();
void CreateWeightsMap();
void WeightOverwrite();