mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-09-20 07:42:21 +03:00
Introduce margin scale factor, change printing of feature vectors
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/branches/mira-mtm5@3715 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
parent
ffc37e64b3
commit
57b0e64cda
@ -82,6 +82,7 @@ int main(int argc, char** argv) {
|
||||
size_t mixFrequency;
|
||||
size_t weightDumpFrequency;
|
||||
string weightDumpStem;
|
||||
float marginScaleFactor;
|
||||
float clipping;
|
||||
po::options_description desc("Allowed options");
|
||||
desc.add_options()
|
||||
@ -97,6 +98,7 @@ int main(int argc, char** argv) {
|
||||
("weight-dump-frequency", po::value<size_t>(&weightDumpFrequency)->default_value(1), "How often per epoch to dump weights")
|
||||
("shuffle", po::value<bool>(&shuffle)->default_value(false), "Shuffle input sentences before processing")
|
||||
("hildreth", po::value<bool>(&hildreth)->default_value(true), "Use Hildreth's optimisation algorithm")
|
||||
("margin-scale-factor,m", po::value<float>(&marginScaleFactor)->default_value(1.0), "Margin scale factor, regularises the update by scaling the enforced margin")
|
||||
("clipping", po::value<float>(&clipping)->default_value(0.01f), "Set a clipping threshold to regularise updates");
|
||||
|
||||
po::options_description cmdline_options;
|
||||
@ -188,7 +190,7 @@ int main(int argc, char** argv) {
|
||||
size_t n = 10; // size of n-best lists
|
||||
if (learner == "mira") {
|
||||
cerr << "Optimising using Mira" << endl;
|
||||
optimiser = new MiraOptimiser(n, hildreth, clipping);
|
||||
optimiser = new MiraOptimiser(n, hildreth, marginScaleFactor, clipping);
|
||||
if (hildreth) {
|
||||
cerr << "Using Hildreth's optimisation algorithm.." << endl;
|
||||
}
|
||||
@ -281,6 +283,7 @@ int main(int argc, char** argv) {
|
||||
bleuScores[batch],
|
||||
false);
|
||||
decoder->cleanup();
|
||||
cerr << "BLEU of oracle: " << oracleBleuScore << endl;
|
||||
|
||||
// Set loss for each sentence as BLEU(oracle) - BLEU(hypothesis)
|
||||
vector< vector<float> > losses(batchSize);
|
||||
|
@ -23,7 +23,10 @@ void MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
ScoreComponentCollection featureValueDiff = oracleFeatureValues;
|
||||
featureValueDiff.MinusEquals(featureValues[i][j]);
|
||||
float modelScoreDiff = featureValueDiff.InnerProduct(currWeights);
|
||||
if (modelScoreDiff < losses[i][j]) {
|
||||
cerr << "loss of hypothesis: " << losses[i][j] << endl;
|
||||
cerr << "model score difference: " << modelScoreDiff << endl;
|
||||
float loss = losses[i][j] * m_marginScaleFactor;
|
||||
if (modelScoreDiff < loss) {
|
||||
++violatedConstraintsBefore;
|
||||
}
|
||||
|
||||
@ -32,7 +35,7 @@ void MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
// 1. vector Delta_h_ij of the feature value differences (oracle - hypothesis)
|
||||
// 2. loss_ij - difference in model scores (Delta_h_ij.w') (oracle - hypothesis)
|
||||
featureValueDiffs.push_back(featureValueDiff);
|
||||
float lossMarginDistance = losses[i][j] - modelScoreDiff;
|
||||
float lossMarginDistance = loss - modelScoreDiff;
|
||||
lossMarginDistances.push_back(lossMarginDistance);
|
||||
}
|
||||
}
|
||||
@ -40,13 +43,12 @@ void MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
if (violatedConstraintsBefore > 0) {
|
||||
// TODO: slack? margin scale factor?
|
||||
// run optimisation
|
||||
cerr << "Number of violated constraints: " << violatedConstraintsBefore << endl;
|
||||
cerr << "\nNumber of violated constraints: " << violatedConstraintsBefore << endl;
|
||||
// compute deltas for all given constraints
|
||||
vector< float> deltas = Hildreth::optimise(featureValueDiffs, lossMarginDistances);
|
||||
|
||||
// Update the weight vector according to the deltas and the feature value differences
|
||||
// * w' = w' + delta * Dh_ij ---> w' = w' + delta * (h(e*) - h(e_ij))
|
||||
ScoreComponentCollection oldWeights(currWeights);
|
||||
for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
|
||||
cerr << "delta: " << deltas[k] << endl;
|
||||
|
||||
@ -64,9 +66,12 @@ void MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
ScoreComponentCollection featureValueDiff = oracleFeatureValues;
|
||||
featureValueDiff.MinusEquals(featureValues[i][j]);
|
||||
float modelScoreDiff = featureValueDiff.InnerProduct(currWeights);
|
||||
if (modelScoreDiff < losses[i][j]) {
|
||||
float loss = losses[i][j] * m_marginScaleFactor;
|
||||
if (modelScoreDiff < loss) {
|
||||
++violatedConstraintsAfter;
|
||||
}
|
||||
|
||||
cerr << "New model score difference: " << modelScoreDiff << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -59,10 +59,11 @@ namespace Mira {
|
||||
MiraOptimiser() :
|
||||
Optimiser() { }
|
||||
|
||||
MiraOptimiser(size_t n, bool hildreth, float clipping) :
|
||||
MiraOptimiser(size_t n, bool hildreth, float marginScaleFactor, float clipping) :
|
||||
Optimiser(),
|
||||
m_n(n),
|
||||
m_hildreth(hildreth),
|
||||
m_marginScaleFactor(marginScaleFactor),
|
||||
m_c(clipping) { }
|
||||
|
||||
~MiraOptimiser() {}
|
||||
@ -84,8 +85,12 @@ namespace Mira {
|
||||
// number of hypotheses used for each nbest list (number of hope, fear, best model translations)
|
||||
size_t m_n;
|
||||
|
||||
// whether or not to use the Hildreth algorithm in the optimisation step
|
||||
bool m_hildreth;
|
||||
|
||||
// scaling the margin to regularise updates
|
||||
float m_marginScaleFactor;
|
||||
|
||||
// clipping threshold to regularise updates
|
||||
float m_c;
|
||||
};
|
||||
|
@ -175,8 +175,12 @@ namespace Moses {
|
||||
if (i->first != DEFAULT_NAME) {
|
||||
value += get(DEFAULT_NAME);
|
||||
}
|
||||
if (i->first != DEFAULT_NAME && i->second != 0.0) {
|
||||
out << i->first << "=" << value << ", ";
|
||||
/*if (i->first != DEFAULT_NAME && i->second != 0.0) {
|
||||
out << value << ", ";
|
||||
out << i->first << "=" << value << ", ";
|
||||
}*/
|
||||
if (i->first != DEFAULT_NAME) {
|
||||
out << value << ", ";
|
||||
}
|
||||
}
|
||||
out << "}";
|
||||
|
Loading…
Reference in New Issue
Block a user