This commit is contained in:
Hieu Hoang 2013-08-14 12:37:22 +01:00
parent bfdedf682b
commit 8419a3af35
5 changed files with 39 additions and 40 deletions

View File

@ -172,7 +172,7 @@ int main(int argc, char** argv)
("sparse-r0", po::value<float>(&sparse_r0)->default_value(1.0), "Start learning rate for sparse features") ("sparse-r0", po::value<float>(&sparse_r0)->default_value(1.0), "Start learning rate for sparse features")
("decay-core", po::value<float>(&decay_core)->default_value(0.01), "Decay for core feature learning rate") ("decay-core", po::value<float>(&decay_core)->default_value(0.01), "Decay for core feature learning rate")
("decay-sparse", po::value<float>(&decay_sparse)->default_value(0.01), "Decay for sparse feature learning rate") ("decay-sparse", po::value<float>(&decay_sparse)->default_value(0.01), "Decay for sparse feature learning rate")
("tie-bw-to-lm", po::value<bool>(&bleu_weight_lm)->default_value(true), "Make bleu weight depend on lm weight") ("tie-bw-to-lm", po::value<bool>(&bleu_weight_lm)->default_value(true), "Make bleu weight depend on lm weight")
("bw-lm-factor", po::value<float>(&bleu_weight_lm_factor)->default_value(2.0), "Make bleu weight depend on lm weight by this factor") ("bw-lm-factor", po::value<float>(&bleu_weight_lm_factor)->default_value(2.0), "Make bleu weight depend on lm weight by this factor")
("bw-factor-fear", po::value<float>(&bleu_weight_fear_factor)->default_value(1.0), "Multiply fear weight by this factor") ("bw-factor-fear", po::value<float>(&bleu_weight_fear_factor)->default_value(1.0), "Multiply fear weight by this factor")
@ -450,8 +450,8 @@ int main(int argc, char** argv)
if (normaliseMargin) if (normaliseMargin)
cerr << "sigmoid parameter: " << sigmoidParam << endl; cerr << "sigmoid parameter: " << sigmoidParam << endl;
} }
optimiser = new MiraOptimiser(slack, scale_margin, scale_update, boost, normaliseMargin, sigmoidParam); optimiser = new MiraOptimiser(slack, scale_margin, scale_update, boost, normaliseMargin, sigmoidParam);
learning_rate = mira_learning_rate; learning_rate = mira_learning_rate;
perceptron_update = false; perceptron_update = false;
} else if (learner == "perceptron") { } else if (learner == "perceptron") {
if (rank == 0) { if (rank == 0) {
@ -1026,20 +1026,19 @@ int main(int argc, char** argv)
float bleuDiff = bleuHope - bleuScores[batchPosition][i]; float bleuDiff = bleuHope - bleuScores[batchPosition][i];
float modelDiff = modelScores[batchPosition][indexHope] - modelScores[batchPosition][i]; float modelDiff = modelScores[batchPosition][indexHope] - modelScores[batchPosition][i];
if ((bleuDiff > epsilon) && (modelDiff < bleuDiff)) { if ((bleuDiff > epsilon) && (modelDiff < bleuDiff)) {
float diff = bleuDiff - modelDiff; float diff = bleuDiff - modelDiff;
if (diff > epsilon) { if (diff > epsilon) {
if (all_violated) { if (all_violated) {
cerr << ".. adding pair"; cerr << ".. adding pair";
bleuHopeList.push_back(bleuHope); bleuHopeList.push_back(bleuHope);
bleuFearList.push_back(bleuScores[batchPosition][i]); bleuFearList.push_back(bleuScores[batchPosition][i]);
indexHopeList.push_back(indexHope); indexHopeList.push_back(indexHope);
indexFearList.push_back(i); indexFearList.push_back(i);
} } else if (most_violated && diff > currentViolation) {
else if (most_violated && diff > currentViolation) { currentViolation = diff;
currentViolation = diff; bleuFear = bleuScores[batchPosition][i];
bleuFear = bleuScores[batchPosition][i]; indexFear = i;
indexFear = i; cerr << "Rank " << rank << ", epoch " << epoch << ", current violation: " << currentViolation << " (" << modelDiff << " >= " << bleuDiff << ")" << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", current violation: " << currentViolation << " (" << modelDiff << " >= " << bleuDiff << ")" << endl;
} }
} }
} }

View File

@ -70,16 +70,16 @@ public:
Optimiser() { } Optimiser() { }
MiraOptimiser(float slack) : MiraOptimiser(float slack) :
Optimiser(), Optimiser(),
m_slack(slack), m_slack(slack),
m_scale_margin(false), m_scale_margin(false),
m_scale_update(false), m_scale_update(false),
m_boost(false), m_boost(false),
m_normaliseMargin(false), m_normaliseMargin(false),
m_sigmoidParam(1.0) { } m_sigmoidParam(1.0) { }
MiraOptimiser(float slack, bool scale_margin, bool scale_update, MiraOptimiser(float slack, bool scale_margin, bool scale_update,
bool boost, bool normaliseMargin, float sigmoidParam) : bool boost, bool normaliseMargin, float sigmoidParam) :
Optimiser(), Optimiser(),
m_slack(slack), m_slack(slack),
m_scale_margin(scale_margin), m_scale_margin(scale_margin),

View File

@ -223,7 +223,7 @@ void OutputSurface(std::ostream &out, const Hypothesis &edge, const std::vector<
out << "|" << sourceStart << "-" << sourceEnd; out << "|" << sourceStart << "-" << sourceEnd;
// enriched "-tt" // enriched "-tt"
if (reportSegmentation == 2) { if (reportSegmentation == 2) {
out << ",0, "; out << ",0, ";
const AlignmentInfo &ai = edge.GetCurrTargetPhrase().GetAlignTerm(); const AlignmentInfo &ai = edge.GetCurrTargetPhrase().GetAlignTerm();
OutputAlignment(out, ai, 0, 0); OutputAlignment(out, ai, 0, 0);
} }

View File

@ -143,8 +143,8 @@ void BleuScoreFeature::SetParameter(const std::string& key, const std::string& v
std::vector<float> BleuScoreFeature::DefaultWeights() const std::vector<float> BleuScoreFeature::DefaultWeights() const
{ {
std::vector<float> ret(m_numScoreComponents, 1); std::vector<float> ret(m_numScoreComponents, 1);
return ret; return ret;
} }
void BleuScoreFeature::PrintHistory(std::ostream& out) const void BleuScoreFeature::PrintHistory(std::ostream& out) const

View File

@ -361,15 +361,15 @@ bool Parameter::LoadParam(int argc, char* argv[])
void Parameter::AddFeaturesCmd() void Parameter::AddFeaturesCmd()
{ {
if (!isParamSpecified("feature-add")) { if (!isParamSpecified("feature-add")) {
return; return;
} }
const PARAM_VEC &params = GetParam("feature-add"); const PARAM_VEC &params = GetParam("feature-add");
PARAM_VEC::const_iterator iter; PARAM_VEC::const_iterator iter;
for (iter = params.begin(); iter != params.end(); ++iter) { for (iter = params.begin(); iter != params.end(); ++iter) {
const string &line = *iter; const string &line = *iter;
AddFeature(line); AddFeature(line);
} }
m_setting.erase("feature-add"); m_setting.erase("feature-add");
@ -838,15 +838,15 @@ void Parameter::ConvertWeightArgsWordPenalty()
void Parameter::ConvertPhrasePenalty() void Parameter::ConvertPhrasePenalty()
{ {
string oldWeightName = "weight-p"; string oldWeightName = "weight-p";
if (isParamSpecified(oldWeightName)) { if (isParamSpecified(oldWeightName)) {
CHECK(m_setting[oldWeightName].size() == 1); CHECK(m_setting[oldWeightName].size() == 1);
float weight = Scan<float>(m_setting[oldWeightName][0]); float weight = Scan<float>(m_setting[oldWeightName][0]);
AddFeature("PhrasePenalty"); AddFeature("PhrasePenalty");
SetWeight("PhrasePenalty", 0, weight); SetWeight("PhrasePenalty", 0, weight);
m_setting.erase(oldWeightName); m_setting.erase(oldWeightName);
} }
} }
void Parameter::ConvertWeightArgs() void Parameter::ConvertWeightArgs()