Track the number of violated constraints and constraint changes

git-svn-id: http://svn.statmt.org/repository/mira@3869 cc96ff50-19ce-11e0-b349-13d7f0bd23df
This commit is contained in:
ehasler 2011-04-10 22:07:41 +00:00 committed by Ondrej Bojar
parent ac307d9e8c
commit 16edd0b9bc
4 changed files with 44 additions and 24 deletions

View File

@ -193,7 +193,7 @@ int main(int argc, char** argv) {
("stop-dev-bleu", po::value<bool>(&stop_dev_bleu)->default_value(false), "Stop when average Bleu (dev) decreases (or no more increases)")
("stop-approx-dev-bleu", po::value<bool>(&stop_approx_dev_bleu)->default_value(false), "Stop when average approx. sentence Bleu (dev) decreases (or no more increases)")
("stop-weights", po::value<bool>(&weightConvergence)->default_value(false), "Stop when weights converge")
("stop-optimal", po::value<bool>(&stop_optimal)->default_value(false), "Stop when the results of optimization do not improve further")
("stop-optimal", po::value<bool>(&stop_optimal)->default_value(true), "Stop when the results of optimization do not improve further")
("updates-per-epoch", po::value<int>(&updates_per_epoch)->default_value(-1), "Accumulate updates and apply them to the weight vector the specified number of times per epoch")
("use-scaled-reference", po::value<bool>(&useScaledReference)->default_value(true), "Use scaled reference length for comparing target and reference length of phrases")
("verbosity,v", po::value<int>(&verbosity)->default_value(0), "Verbosity level")
@ -381,9 +381,10 @@ int main(int argc, char** argv) {
float prevAverageApproxBleu = 0;
float beforePrevAverageApproxBleu = 0;
bool stop = false;
bool weightsUpdated;
size_t sumViolConstAfterOpt;
float sumErrorAfterOpt;
size_t sumStillViolatedConstraints;
size_t sumStillViolatedConstraints_lastEpoch = 0;
size_t sumConstraintChangeAbs;
size_t sumConstraintChangeAbs_lastEpoch = 0;
float *sendbuf, *recvbuf;
sendbuf = (float *) malloc(sizeof(float));
recvbuf = (float *) malloc(sizeof(float));
@ -391,12 +392,9 @@ int main(int argc, char** argv) {
for (size_t epoch = 0; epoch < epochs && !stop; ++epoch) {
cerr << "\nRank " << rank << ", epoch " << epoch << endl;
// track whether there is any weight update this epoch
weightsUpdated = false;
// sum of violated constraints and error after optimization
sumViolConstAfterOpt = 0;
sumErrorAfterOpt = 0;
// sum of violated constraints
sumStillViolatedConstraints = 0;
sumConstraintChangeAbs = 0;
// sum of approx. sentence bleu scores per epoch
summedApproxBleu = 0;
@ -550,18 +548,19 @@ int main(int argc, char** argv) {
// run optimiser on batch
cerr << "\nRank " << rank << ", run optimiser:" << endl;
ScoreComponentCollection oldWeights(mosesWeights);
int updateStatus = optimiser->updateWeights(mosesWeights, featureValues,
vector<int> update_status = optimiser->updateWeights(mosesWeights, featureValues,
losses, bleuScores, oracleFeatureValues, oracleBleuScores, ref_ids,
learning_rate, max_sentence_update, rank, updates_per_epoch, controlUpdates);
if (updateStatus == 1) {
if (update_status[0] == 1) {
cerr << "Rank " << rank << ", no update for batch" << endl;
}
else if (updateStatus == -1) {
else if (update_status[0] == -1) {
cerr << "Rank " << rank << ", update ignored" << endl;
}
else {
weightsUpdated = true;
sumConstraintChangeAbs += abs(update_status[1] - update_status[2]);
sumStillViolatedConstraints += update_status[2];
if (updates_per_epoch == -1) {
// pass new weights to decoder
@ -773,7 +772,14 @@ int main(int argc, char** argv) {
}
if (stop_optimal) {
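// TODO: this block only reports when both sums are unchanged from the previous epoch; it does not set stop yet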
if (epoch > 0) {
if (sumConstraintChangeAbs_lastEpoch == sumConstraintChangeAbs && sumStillViolatedConstraints_lastEpoch == sumStillViolatedConstraints) {
cerr << "Epoch " << epoch << ", sum of violated constraints and constraint changes has stayed the same: " << sumStillViolatedConstraints << ", " << sumConstraintChangeAbs << endl;
}
}
sumConstraintChangeAbs_lastEpoch = sumConstraintChangeAbs;
sumStillViolatedConstraints_lastEpoch = sumStillViolatedConstraints;
}
if (!stop) {

View File

@ -6,7 +6,7 @@ using namespace std;
namespace Mira {
int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
const vector<vector<ScoreComponentCollection> >& featureValues,
const vector<vector<float> >& losses,
const vector<std::vector<float> >& bleuScores, const vector<
@ -203,7 +203,11 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
}
} else {
cerr << "Rank " << rank << ", check, no constraint violated for this batch" << endl;
return 1;
vector<int> status(3);
status[0] = 1;
status[1] = 0;
status[2] = 0;
return status;
}
// sanity check: still violated constraints after optimisation?
@ -236,7 +240,11 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
float distanceChange = oldDistanceFromOptimum - newDistanceFromOptimum;
cerr << "Rank " << rank << ", check, there are still violated constraints, the distance change is: " << distanceChange << endl;
if (controlUpdates && (violatedConstraintsBefore - violatedConstraintsAfter) < 0 && distanceChange < 0) {
return -1;
vector<int> statusPlus(3);
statusPlus[0] = -1;
statusPlus[1] = violatedConstraintsBefore;
statusPlus[2] = violatedConstraintsAfter;
return statusPlus;
}
}
@ -265,7 +273,11 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
cerr << "Rank " << rank << ", weights after update: " << currWeights << endl;
}
return 0;
vector<int> statusPlus(3);
statusPlus[0] = 0;
statusPlus[1] = violatedConstraintsBefore;
statusPlus[2] = violatedConstraintsAfter;
return statusPlus;
}
}

View File

@ -29,7 +29,7 @@ namespace Mira {
class Optimiser {
public:
Optimiser() {}
virtual int updateWeights(Moses::ScoreComponentCollection& weights,
virtual std::vector<int> updateWeights(Moses::ScoreComponentCollection& weights,
const std::vector< std::vector<Moses::ScoreComponentCollection> >& featureValues,
const std::vector< std::vector<float> >& losses,
const std::vector<std::vector<float> >& bleuScores,
@ -45,7 +45,7 @@ namespace Mira {
class Perceptron : public Optimiser {
public:
virtual int updateWeights(Moses::ScoreComponentCollection& weights,
virtual std::vector<int> updateWeights(Moses::ScoreComponentCollection& weights,
const std::vector< std::vector<Moses::ScoreComponentCollection> >& featureValues,
const std::vector< std::vector<float> >& losses,
const std::vector<std::vector<float> >& bleuScores,
@ -80,7 +80,7 @@ namespace Mira {
~MiraOptimiser() {}
virtual int updateWeights(Moses::ScoreComponentCollection& weights,
virtual std::vector<int> updateWeights(Moses::ScoreComponentCollection& weights,
const std::vector< std::vector<Moses::ScoreComponentCollection> >& featureValues,
const std::vector< std::vector<float> >& losses,
const std::vector<std::vector<float> >& bleuScores,

View File

@ -24,7 +24,7 @@ using namespace std;
namespace Mira {
int Perceptron::updateWeights(ScoreComponentCollection& currWeights,
vector<int> Perceptron::updateWeights(ScoreComponentCollection& currWeights,
const vector< vector<ScoreComponentCollection> >& featureValues,
const vector< vector<float> >& losses,
const vector< vector<float> >& bleuScores,
@ -46,7 +46,9 @@ int Perceptron::updateWeights(ScoreComponentCollection& currWeights,
}
}
return 0;
vector<int> status(1);
status[0] = 0;
return status;
}
}