Add a precision option for constraint checking, fix some log messages, and skip the update when the number of violated constraints is unchanged (change == 0) and the error increases

git-svn-id: http://svn.statmt.org/repository/mira@3888 cc96ff50-19ce-11e0-b349-13d7f0bd23df
This commit is contained in:
ehasler 2011-05-04 18:26:20 +00:00 committed by Ondrej Bojar
parent 921735593d
commit 46cf77fa79
3 changed files with 36 additions and 29 deletions

View File

@ -186,6 +186,7 @@ int main(int argc, char** argv) {
vector<string> burnInReferenceFiles;
bool sentenceLevelBleu;
float bleuScoreWeight;
float precision;
po::options_description desc("Allowed options");
desc.add_options()
("accumulate-most-violated-constraints", po::value<bool>(&accumulateMostViolatedConstraints)->default_value(false),"Accumulate most violated constraint per example")
@ -231,6 +232,7 @@ int main(int argc, char** argv) {
("normalise", po::value<bool>(&normaliseWeights)->default_value(false), "Whether to normalise the updated weights before passing them to the decoder")
("only-violated-constraints", po::value<bool>(&onlyViolatedConstraints)->default_value(false), "Add only violated constraints to the optimisation problem")
("past-and-current-constraints", po::value<bool>(&pastAndCurrentConstraints)->default_value(false), "Accumulate most violated constraint per example and use them along all current constraints")
("precision", po::value<float>(&precision)->default_value(0), "Precision when comparing left and right hand side of constraints")
("print-feature-values", po::value<bool>(&print_feature_values)->default_value(false), "Print out feature values")
("reference-files,r", po::value<vector<string> >(&referenceFiles), "Reference translation files for training")
("scale-by-input-length", po::value<bool>(&scaleByInputLength)->default_value(true), "Scale the BLEU score by a history of the input lengths")
@ -811,7 +813,7 @@ int main(int argc, char** argv) {
cerr << "Optimising using Mira" << endl;
optimiser = new MiraOptimiser(n, hildreth, marginScaleFactor,
onlyViolatedConstraints, slack, weightedLossFunction, maxNumberOracles,
accumulateMostViolatedConstraints, pastAndCurrentConstraints, order.size());
accumulateMostViolatedConstraints, pastAndCurrentConstraints, order.size(), precision);
if (hildreth) {
cerr << "Using Hildreth's optimisation algorithm.." << endl;
}
@ -1009,7 +1011,7 @@ int main(int argc, char** argv) {
// optionally print out the feature values
if (print_feature_values) {
cerr << "\nRank " << rank << ", feature values: " << endl;
cerr << "\nRank " << rank << ", epoch " << epoch << ", feature values: " << endl;
for (size_t i = 0; i < featureValues.size(); ++i) {
for (size_t j = 0; j < featureValues[i].size(); ++j) {
cerr << featureValues[i][j] << endl;
@ -1019,7 +1021,7 @@ int main(int argc, char** argv) {
}
// run optimiser on batch
cerr << "\nRank " << rank << ", run optimiser:" << endl;
cerr << "\nRank " << rank << ", epoch " << epoch << ", run optimiser:" << endl;
ScoreComponentCollection oldWeights(mosesWeights);
vector<int> update_status;
update_status = optimiser->updateWeights(mosesWeights, featureValues,
@ -1054,10 +1056,10 @@ int main(int argc, char** argv) {
}
mosesWeights = averageWeights;
cerr << "Rank " << rank << ", set new average weights: " << mosesWeights << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", set new average weights: " << mosesWeights << endl;
}
else {
cerr << "Rank " << rank << ", set new weights: " << mosesWeights << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", set new weights: " << mosesWeights << endl;
}
// set new Moses weights (averaged or not)
@ -1066,7 +1068,7 @@ int main(int argc, char** argv) {
// compute difference to old weights
ScoreComponentCollection weightDifference(mosesWeights);
weightDifference.MinusEquals(oldWeights);
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", weight difference: " << weightDifference << endl;
// get 1best model results with new weights (for each sentence in batch)
vector<float> bestModelNew;
@ -1084,13 +1086,13 @@ int main(int argc, char** argv) {
// update history (for approximate document Bleu)
if (historyOf1best) {
for (size_t i = 0; i < oneBests.size(); ++i) {
cerr << "Rank " << rank << ", 1best length: " << oneBests[i].size() << " ";
cerr << "Rank " << rank << ", epoch " << epoch << ", 1best length: " << oneBests[i].size() << " ";
}
decoder->updateHistory(oneBests, inputLengths, ref_ids, rank, epoch);
}
else {
for (size_t i = 0; i < oracles.size(); ++i) {
cerr << "Rank " << rank << ", oracle length: " << oracles[i].size() << " ";
cerr << "Rank " << rank << ", epoch " << epoch << ", oracle length: " << oracles[i].size() << " ";
}
decoder->updateHistory(oracles, inputLengths, ref_ids, rank, epoch);
}
@ -1109,7 +1111,7 @@ int main(int argc, char** argv) {
if (makeUpdate && typeid(*optimiser) == typeid(MiraOptimiser)) {
mosesWeights = decoder->getWeights();
ScoreComponentCollection accumulatedUpdates = ((MiraOptimiser*) optimiser)->getAccumulatedUpdates();
cerr << "\nRank " << rank << ", updates to apply during epoch " << epoch << ": " << accumulatedUpdates << endl;
cerr << "\nRank " << rank << ", epoch " << epoch << ", updates to apply during epoch " << epoch << ": " << accumulatedUpdates << endl;
if (accumulatedUpdates.GetWeightedScore() != 0) {
mosesWeights.PlusEquals(accumulatedUpdates);
((MiraOptimiser*) optimiser)->resetAccumulatedUpdates();
@ -1131,10 +1133,10 @@ int main(int argc, char** argv) {
}
mosesWeights = averageWeights;
cerr << "Rank " << rank << ", set new average weights after applying cumulative update: " << mosesWeights << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", set new average weights after applying cumulative update: " << mosesWeights << endl;
}
else {
cerr << "Rank " << rank << ", set new weights after applying cumulative update: " << mosesWeights << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", set new weights after applying cumulative update: " << mosesWeights << endl;
}
decoder->setWeights(mosesWeights);
@ -1142,10 +1144,10 @@ int main(int argc, char** argv) {
// compute difference to old weights
ScoreComponentCollection weightDifference(mosesWeights);
weightDifference.MinusEquals(oldWeights);
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", weight difference: " << weightDifference << endl;
}
else {
cerr << "Rank " << rank << ", cumulative update is empty.." << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", cumulative update is empty.." << endl;
}
}

View File

@ -30,7 +30,7 @@ vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
for (size_t j = 0; j < m_oracles[sentenceId].size(); ++j) {
float currentWeightedScore = m_oracles[sentenceId][j].InnerProduct(currWeights);
if (currentWeightedScore == newWeightedScore) {
cerr << "Rank " << rank << ", Bleu score of oracle updated at batch position " << i << ", " << m_bleu_of_oracles[sentenceId][j] << " --> " << oracleBleuScores[j] << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", bleu score of oracle updated at batch position " << i << ", " << m_bleu_of_oracles[sentenceId][j] << " --> " << oracleBleuScores[j] << endl;
m_bleu_of_oracles[sentenceId][j] = oracleBleuScores[j];
updated = true;
break;
@ -95,8 +95,9 @@ vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
bool addConstraint = true;
float diff = loss - modelScoreDiff;
// cerr << "constraint: " << modelScoreDiff << " >= " << loss << endl;
if (diff > epsilon) {
if (diff > (epsilon + m_precision)) {
violated = true;
cerr << "Rank " << rank << ", epoch " << epoch << ", current violation: " << diff << " (loss: " << loss << ")" << endl;
}
else if (m_onlyViolatedConstraints) {
addConstraint = false;
@ -151,8 +152,9 @@ vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
bool violated = false;
bool addConstraint = true;
float diff = m_losses[i] - modelScoreDiff;
if (diff > epsilon) {
if (diff > (epsilon + m_precision)) {
violated = true;
cerr << "Rank " << rank << ", epoch " << epoch << ", past violation: " << diff << " (loss: " << m_losses[i] << ")" << endl;
}
else if (m_onlyViolatedConstraints) {
addConstraint = false;
@ -254,7 +256,7 @@ vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
float modelScoreDiff = featureValueDiffs[i].InnerProduct(newWeights);
float loss = all_losses[i];
float diff = loss - modelScoreDiff;
if (diff > epsilon) {
if (diff > (epsilon + m_precision)) {
++violatedConstraintsAfter;
newDistanceFromOptimum += diff;
}
@ -264,38 +266,38 @@ vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
if (controlUpdates && violatedConstraintsAfter > 0) {
float distanceChange = oldDistanceFromOptimum - newDistanceFromOptimum;
if ((violatedConstraintsBefore - violatedConstraintsAfter) < 0 && distanceChange < 0) {
if ((violatedConstraintsBefore - violatedConstraintsAfter) <= 0 && distanceChange < 0) {
vector<int> statusPlus(3);
statusPlus[0] = -1;
statusPlus[1] = violatedConstraintsBefore;
statusPlus[2] = violatedConstraintsAfter;
statusPlus[1] = -1;
statusPlus[2] = -1;
return statusPlus;
}
}
// Apply learning rate (fixed or flexible)
if (learning_rate != 1) {
cerr << "Rank " << rank << ", update before applying learning rate: " << summedUpdate << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", update before applying learning rate: " << summedUpdate << endl;
summedUpdate.MultiplyEquals(learning_rate);
cerr << "Rank " << rank << ", update after applying learning rate: " << summedUpdate << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", update after applying learning rate: " << summedUpdate << endl;
}
// Apply threshold scaling
if (max_sentence_update != -1) {
cerr << "Rank " << rank << ", update before scaling to max-sentence-update: " << summedUpdate << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", update before scaling to max-sentence-update: " << summedUpdate << endl;
summedUpdate.ThresholdScaling(max_sentence_update);
cerr << "Rank " << rank << ", update after scaling to max-sentence-update: " << summedUpdate << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", update after scaling to max-sentence-update: " << summedUpdate << endl;
}
// Apply update to weight vector or store it for later
if (updates_per_epoch > 0) {
m_accumulatedUpdates.PlusEquals(summedUpdate);
cerr << "Rank " << rank << ", new accumulated updates:" << m_accumulatedUpdates << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", new accumulated updates:" << m_accumulatedUpdates << endl;
} else {
// apply update to weight vector
cerr << "Rank " << rank << ", weights before update: " << currWeights << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", weights before update: " << currWeights << endl;
currWeights.PlusEquals(summedUpdate);
cerr << "Rank " << rank << ", weights after update: " << currWeights << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", weights after update: " << currWeights << endl;
}
vector<int> statusPlus(3);

View File

@ -88,7 +88,7 @@ namespace Mira {
MiraOptimiser() :
Optimiser() { }
MiraOptimiser(size_t n, bool hildreth, float marginScaleFactor, bool onlyViolatedConstraints, float slack, size_t weightedLossFunction, size_t maxNumberOracles, bool accumulateMostViolatedConstraints, bool pastAndCurrentConstraints, size_t exampleSize) :
MiraOptimiser(size_t n, bool hildreth, float marginScaleFactor, bool onlyViolatedConstraints, float slack, size_t weightedLossFunction, size_t maxNumberOracles, bool accumulateMostViolatedConstraints, bool pastAndCurrentConstraints, size_t exampleSize, float precision) :
Optimiser(),
m_n(n),
m_hildreth(hildreth),
@ -100,7 +100,8 @@ namespace Mira {
m_accumulateMostViolatedConstraints(accumulateMostViolatedConstraints),
m_pastAndCurrentConstraints(pastAndCurrentConstraints),
m_oracles(exampleSize),
m_bleu_of_oracles(exampleSize) { }
m_bleu_of_oracles(exampleSize),
m_precision(precision) { }
~MiraOptimiser() {}
@ -187,6 +188,8 @@ namespace Mira {
bool m_pastAndCurrentConstraints;
Moses::ScoreComponentCollection m_accumulatedUpdates;
float m_precision;
};
}