fix weight scaling

Eva Hasler 2012-05-02 22:54:23 +01:00
parent 702c9d1b81
commit dd9bd42a62
3 changed files with 80 additions and 43 deletions

View File

@@ -117,8 +117,8 @@ int main(int argc, char** argv) {
   float scale_lm_factor, bleu_weight_lm_factor, scale_wp_factor;
   bool sample;
   string moses_src;
-  bool external_score = false, scale_all;
-  float dummy, sigmoidParam, scale_all_factor;
+  bool external_score = false, scale_all, dummy;
+  float sigmoidParam, scale_all_factor;
   po::options_description desc("Allowed options");
   desc.add_options()
   ("accumulate-weights", po::value<bool>(&accumulateWeights)->default_value(false), "Accumulate and average weights over all epochs")
@@ -145,7 +145,7 @@ int main(int argc, char** argv) {
   ("decode-filename", po::value<string>(&decode_filename), "Filename for Bleu objective translations")
   ("decoder-settings", po::value<string>(&decoder_settings)->default_value(""), "Decoder settings for tuning runs")
   ("distinct-nbest", po::value<bool>(&distinctNbest)->default_value(true), "Use n-best list with distinct translations in inference step")
-  ("dummy", po::value<float>(&dummy)->default_value(1.0), "****")
+  ("dummy", po::value<bool>(&dummy)->default_value(false), "****")
   ("dump-mixed-weights", po::value<bool>(&dumpMixedWeights)->default_value(false), "Dump mixed weights instead of averaged weights")
   ("epochs,e", po::value<size_t>(&epochs)->default_value(10), "Number of epochs")
   ("feature-cutoff", po::value<int>(&featureCutoff)->default_value(-1), "Feature cutoff as additional regularization for sparse features")
@@ -567,7 +567,13 @@ int main(int argc, char** argv) {
   }
   decoder->setWeights(initialWeights);
+  if (dummy == true) {
+    scale_all = true;
+    scale_all_factor = 2;
+  }
   if (scale_all) {
+    cerr << "Scale all core features by factor " << scale_all_factor << endl;
     scale_lm = true;
     scale_wp = true;
     scale_lm_factor = scale_all_factor;
@@ -583,9 +589,6 @@ int main(int argc, char** argv) {
   bleuWeight = lmSum * bleu_weight_lm_factor;
   }
-  if (dummy != 1.0)
-    bleuWeight = dummy;
-
   if (bleuWeight_hope == -1) {
     bleuWeight_hope = bleuWeight;
   }
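
Editor's note: taken together, the two hunks above change what --dummy means. It no longer overrides bleuWeight; instead it expands into the general scale-all configuration with a fixed factor of 2. A minimal sketch of that expansion with plain local variables (not the Moses code itself):

    #include <iostream>

    int main() {
      bool dummy = true;            // as if --dummy true had been passed
      bool scale_all = false;
      float scale_all_factor = 1.0f;

      // --dummy is now just shorthand for scaling all core features by 2
      if (dummy) {
        scale_all = true;
        scale_all_factor = 2.0f;
      }
      if (scale_all)
        std::cerr << "Scale all core features by factor " << scale_all_factor << std::endl;
      return 0;
    }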
@@ -1373,7 +1376,7 @@ int main(int argc, char** argv) {
   // scale WP
   if (scale_wp) {
     // scale up weight
-    WordPenaltyProducer *wp = staticData.GetWordPenaltyProducer();
+    WordPenaltyProducer *wp = staticData.GetFirstWordPenaltyProducer();
     float wpWeight = mosesWeights.GetScoreForProducer(wp);
     mosesWeights.Assign(wp, wpWeight*scale_wp_factor);
     cerr << "Rank " << rank << ", epoch " << epoch << ", wp weight scaled from " << wpWeight << " to " << wpWeight*scale_wp_factor << endl;
@@ -1409,42 +1412,52 @@ int main(int argc, char** argv) {
   }
   // scale lexical reordering models
-  vector<LexicalReordering*> lr = staticData.GetLexicalReorderModels();
-  for (size_t i=0; i<lr.size(); ++i) {
+  vector<LexicalReordering*> lrVec = staticData.GetLexicalReorderModels();
+  for (size_t i=0; i<lrVec.size(); ++i) {
+    LexicalReordering* lr = lrVec[i];
     // scale up weight
-    dWeight = mosesWeights.GetScoreForProducer(lr[i]);
-    mosesWeights.Assign(lr[i], dWeight*scale_all_factor);
-    cerr << "Rank " << rank << ", epoch " << epoch << ", d weight scaled from " << dWeight << " to " << dWeight*scale_all_factor << endl;
+    vector<float> dWeights = mosesWeights.GetScoresForProducer(lr);
+    for (size_t j=0; j<dWeights.size(); ++j) {
+      cerr << "Rank " << rank << ", epoch " << epoch << ", d weight scaled from " << dWeights[j];
+      dWeights[j] *= scale_all_factor;
+      cerr << " to " << dWeights[j] << endl;
+    }
+    mosesWeights.Assign(lr, dWeights);

     // scale down score
     if (sample) {
-      scaleFeatureScore(lr[i], scale_all_factor, featureValuesHopeSample, rank, epoch);
-      scaleFeatureScore(lr[i], scale_all_factor, featureValuesFearSample, rank, epoch);
+      scaleFeatureScores(lr, scale_all_factor, featureValuesHopeSample, rank, epoch);
+      scaleFeatureScores(lr, scale_all_factor, featureValuesFearSample, rank, epoch);
     }
     else {
-      scaleFeatureScore(lr[i], scale_all_factor, featureValuesHope, rank, epoch);
-      scaleFeatureScore(lr[i], scale_all_factor, featureValuesFear, rank, epoch);
-      scaleFeatureScore(lr[i], scale_all_factor, featureValues, rank, epoch);
+      scaleFeatureScores(lr, scale_all_factor, featureValuesHope, rank, epoch);
+      scaleFeatureScores(lr, scale_all_factor, featureValuesFear, rank, epoch);
+      scaleFeatureScores(lr, scale_all_factor, featureValues, rank, epoch);
     }
   }
   // scale phrase table models
-  vector<PhraseDictionaryFeature*> pd = staticData.GetPhraseDictionaryModels();
-  for (size_t i=0; i<pd.size(); ++i) {
+  vector<PhraseDictionaryFeature*> pdVec = staticData.GetPhraseDictionaryModels();
+  for (size_t i=0; i<pdVec.size(); ++i) {
+    PhraseDictionaryFeature* pd = pdVec[i];
     // scale up weight
-    float tWeight = mosesWeights.GetScoreForProducer(pd[i]);
-    mosesWeights.Assign(pd[i], tWeight*scale_all_factor);
-    cerr << "Rank " << rank << ", epoch " << epoch << ", t weight scaled from " << tWeight << " to " << tWeight*scale_all_factor << endl;
+    vector<float> tWeights = mosesWeights.GetScoresForProducer(pd);
+    for (size_t j=0; j<tWeights.size(); ++j) {
+      cerr << "Rank " << rank << ", epoch " << epoch << ", t weight scaled from " << tWeights[j];
+      tWeights[j] *= scale_all_factor;
+      cerr << " to " << tWeights[j] << endl;
+    }
+    mosesWeights.Assign(pd, tWeights);

     // scale down score
     if (sample) {
-      scaleFeatureScore(pd[i], scale_all_factor, featureValuesHopeSample, rank, epoch);
-      scaleFeatureScore(pd[i], scale_all_factor, featureValuesFearSample, rank, epoch);
+      scaleFeatureScores(pd, scale_all_factor, featureValuesHopeSample, rank, epoch);
+      scaleFeatureScores(pd, scale_all_factor, featureValuesFearSample, rank, epoch);
     }
     else {
-      scaleFeatureScore(pd[i], scale_all_factor, featureValuesHope, rank, epoch);
-      scaleFeatureScore(pd[i], scale_all_factor, featureValuesFear, rank, epoch);
-      scaleFeatureScore(pd[i], scale_all_factor, featureValues, rank, epoch);
+      scaleFeatureScores(pd, scale_all_factor, featureValuesHope, rank, epoch);
+      scaleFeatureScores(pd, scale_all_factor, featureValuesFear, rank, epoch);
+      scaleFeatureScores(pd, scale_all_factor, featureValues, rank, epoch);
     }
   }
 }
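
Editor's note: this hunk carries the actual fix. Lexical reordering and phrase table producers own several weights each, so the old scalar GetScoreForProducer/Assign pair silently scaled only a single component; the corrected loop scales the whole vector returned by GetScoresForProducer. A standalone sketch of the pattern with a plain vector (the six example values are illustrative):

    #include <iostream>
    #include <vector>

    int main() {
      // e.g. six lexical reordering weights belonging to one producer
      std::vector<float> dWeights = {0.3f, 0.2f, 0.3f, 0.1f, 0.2f, 0.1f};
      const float scale_all_factor = 2.0f;

      // scale every component of the producer's weight block, not just one
      for (size_t j = 0; j < dWeights.size(); ++j) {
        std::cerr << "d weight scaled from " << dWeights[j];
        dWeights[j] *= scale_all_factor;
        std::cerr << " to " << dWeights[j] << std::endl;
      }
      // in Moses the whole vector would now be written back with
      // mosesWeights.Assign(lr, dWeights);
      return 0;
    }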
@@ -1539,7 +1552,7 @@ int main(int argc, char** argv) {
   // rescale WP feature
   if (scale_wp) {
     // scale weight back down
-    WordPenaltyProducer *wp = staticData.GetWordPenaltyProducer();
+    WordPenaltyProducer *wp = staticData.GetFirstWordPenaltyProducer();
     float wpWeight = mosesWeights.GetScoreForProducer(wp);
     mosesWeights.Assign(wp, wpWeight/scale_wp_factor);
     cerr << "Rank " << rank << ", epoch " << epoch << ", wp weight rescaled from " << wpWeight << " to " << wpWeight/scale_wp_factor << endl;
@@ -1555,17 +1568,25 @@ int main(int argc, char** argv) {
   // rescale lexical reordering
   vector<LexicalReordering*> lr = staticData.GetLexicalReorderModels();
   for (size_t i=0; i<lr.size(); ++i) {
-    dWeight = mosesWeights.GetScoreForProducer(lr[i]);
-    mosesWeights.Assign(lr[i], dWeight/scale_all_factor);
-    cerr << "Rank " << rank << ", epoch " << epoch << ", d weight rescaled from " << dWeight << " to " << dWeight/scale_all_factor << endl;
+    vector<float> dWeights = mosesWeights.GetScoresForProducer(lr[i]);
+    for (size_t j=0; j<dWeights.size(); ++j) {
+      cerr << "Rank " << rank << ", epoch " << epoch << ", d weight rescaled from " << dWeights[j];
+      dWeights[j] /= scale_all_factor;
+      cerr << " to " << dWeights[j] << endl;
+    }
+    mosesWeights.Assign(lr[i], dWeights);
   }
   // rescale phrase models
   vector<PhraseDictionaryFeature*> pd = staticData.GetPhraseDictionaryModels();
   for (size_t i=0; i<pd.size(); ++i) {
-    float tWeight = mosesWeights.GetScoreForProducer(pd[i]);
-    mosesWeights.Assign(pd[i], tWeight/scale_all_factor);
-    cerr << "Rank " << rank << ", epoch " << epoch << ", t weight rescaled from " << tWeight << " to " << tWeight/scale_all_factor << endl;
+    vector<float> tWeights = mosesWeights.GetScoresForProducer(pd[i]);
+    for (size_t j=0; j<tWeights.size(); ++j) {
+      cerr << "Rank " << rank << ", epoch " << epoch << ", t weight rescaled from " << tWeights[j];
+      tWeights[j] /= scale_all_factor;
+      cerr << " to " << tWeights[j] << endl;
+    }
+    mosesWeights.Assign(pd[i], tWeights);
   }
 }
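
Editor's note: the rescaling hunk mirrors the scale-up hunk component for component, so every weight returns to its starting value after decoding. A minimal check of that invariant with plain vectors (factor 2 is a power of two, which keeps the float round trip exact):

    #include <cassert>
    #include <vector>

    int main() {
      std::vector<float> w = {0.3f, 0.2f, 0.1f};
      const std::vector<float> original = w;
      const float factor = 2.0f;  // power of two: division is an exact inverse

      for (size_t j = 0; j < w.size(); ++j) w[j] *= factor;  // scale up
      for (size_t j = 0; j < w.size(); ++j) w[j] /= factor;  // scale back down

      for (size_t j = 0; j < w.size(); ++j)
        assert(w[j] == original[j]);  // weights are restored exactly
      return 0;
    }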
@@ -2118,3 +2139,18 @@ void scaleFeatureScore(ScoreProducer *sp, float scaling_factor, vector<vector<ScoreComponentCollection> > &featureValues, size_t rank, size_t epoch) {
     }
   }
 }
+
+void scaleFeatureScores(ScoreProducer *sp, float scaling_factor, vector<vector<ScoreComponentCollection> > &featureValues, size_t rank, size_t epoch) {
+  string name = sp->GetScoreProducerWeightShortName();
+  // scale down score
+  for (size_t i=0; i<featureValues.size(); ++i) { // each item in batch
+    for (size_t j=0; j<featureValues[i].size(); ++j) { // each item in nbest
+      vector<float> featureScores = featureValues[i][j].GetScoresForProducer(sp);
+      for (size_t k=0; k<featureScores.size(); ++k)
+        featureScores[k] /= scaling_factor;
+      featureValues[i][j].Assign(sp, featureScores);
+      //cerr << "Rank " << rank << ", epoch " << epoch << ", " << name << " score scaled from " << featureScore << " to " << featureScore/scaling_factor << endl;
+    }
+  }
+}
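
Editor's note: the new helper generalizes scaleFeatureScore to producers with several scores per hypothesis. Because the weights were scaled up, the matching feature scores of every n-best entry are divided back down so the weight-times-score products stay comparable. A self-contained sketch of the same triple loop, with nested vectors standing in for the batch / n-best / per-feature structure (not the Moses types):

    #include <iostream>
    #include <vector>

    void scaleFeatureScores(std::vector<std::vector<std::vector<float> > >& featureValues,
                            float scaling_factor) {
      for (size_t i = 0; i < featureValues.size(); ++i)            // each item in batch
        for (size_t j = 0; j < featureValues[i].size(); ++j)       // each item in nbest
          for (size_t k = 0; k < featureValues[i][j].size(); ++k)  // each feature score
            featureValues[i][j][k] /= scaling_factor;
    }

    int main() {
      // one input sentence, two n-best entries, two feature scores each
      std::vector<std::vector<std::vector<float> > > fv = { { {4.0f, 2.0f}, {6.0f, 8.0f} } };
      scaleFeatureScores(fv, 2.0f);
      std::cout << fv[0][0][0] << " " << fv[0][1][1] << std::endl;  // prints "2 4"
      return 0;
    }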

View File

@@ -52,5 +52,6 @@ void takeLogs(std::vector<std::vector<Moses::ScoreComponentCollection> > &featur
 void deleteTranslations(std::vector<std::vector<const Moses::Word*> > &translations);
 void decodeHopeOrFear(size_t rank, size_t size, size_t decode, std::string decode_filename, std::vector<std::string> &inputSentences, Mira::MosesDecoder* decoder, size_t n);
 void scaleFeatureScore(Moses::ScoreProducer *sp, float scaling_factor, std::vector<std::vector<Moses::ScoreComponentCollection> > &featureValues, size_t rank, size_t epoch);
+void scaleFeatureScores(Moses::ScoreProducer *sp, float scaling_factor, std::vector<std::vector<Moses::ScoreComponentCollection> > &featureValues, size_t rank, size_t epoch);

 #endif /* MAIN_H_ */

View File

@@ -485,7 +485,7 @@ public:
   LMList GetLMList() const {
     return m_languageModel;
   }
-  WordPenaltyProducer* GetWordPenaltyProducer() const {
+  WordPenaltyProducer* GetFirstWordPenaltyProducer() const {
     assert(m_wordPenaltyProducers.size() >= 1);
     return m_wordPenaltyProducers[0];
   }
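
Editor's note: the rename is purely for honesty. StaticData keeps a vector of word penalty producers, and this accessor only ever returned the first one; the new name says so. A minimal mock of the accessor pattern (the Add method is an assumption of this sketch, not Moses API):

    #include <cassert>
    #include <vector>

    struct WordPenaltyProducer { float weight; };

    // mock of the StaticData accessor, not the real Moses class
    class StaticData {
      std::vector<WordPenaltyProducer*> m_wordPenaltyProducers;
    public:
      void Add(WordPenaltyProducer* wp) { m_wordPenaltyProducers.push_back(wp); }
      // name states that only the first of possibly several producers is returned
      WordPenaltyProducer* GetFirstWordPenaltyProducer() const {
        assert(m_wordPenaltyProducers.size() >= 1);
        return m_wordPenaltyProducers[0];
      }
    };

    int main() {
      WordPenaltyProducer wp = { -1.0f };
      StaticData sd;
      sd.Add(&wp);
      assert(sd.GetFirstWordPenaltyProducer() == &wp);
      return 0;
    }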