mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-09-20 07:42:21 +03:00
track the number of violated constraints and constraint changes
git-svn-id: http://svn.statmt.org/repository/mira@3869 cc96ff50-19ce-11e0-b349-13d7f0bd23df
This commit is contained in:
parent
ac307d9e8c
commit
16edd0b9bc
@ -193,7 +193,7 @@ int main(int argc, char** argv) {
|
||||
("stop-dev-bleu", po::value<bool>(&stop_dev_bleu)->default_value(false), "Stop when average Bleu (dev) decreases (or no more increases)")
|
||||
("stop-approx-dev-bleu", po::value<bool>(&stop_approx_dev_bleu)->default_value(false), "Stop when average approx. sentence Bleu (dev) decreases (or no more increases)")
|
||||
("stop-weights", po::value<bool>(&weightConvergence)->default_value(false), "Stop when weights converge")
|
||||
("stop-optimal", po::value<bool>(&stop_optimal)->default_value(false), "Stop when the results of optimization do not improve further")
|
||||
("stop-optimal", po::value<bool>(&stop_optimal)->default_value(true), "Stop when the results of optimization do not improve further")
|
||||
("updates-per-epoch", po::value<int>(&updates_per_epoch)->default_value(-1), "Accumulate updates and apply them to the weight vector the specified number of times per epoch")
|
||||
("use-scaled-reference", po::value<bool>(&useScaledReference)->default_value(true), "Use scaled reference length for comparing target and reference length of phrases")
|
||||
("verbosity,v", po::value<int>(&verbosity)->default_value(0), "Verbosity level")
|
||||
@ -381,9 +381,10 @@ int main(int argc, char** argv) {
|
||||
float prevAverageApproxBleu = 0;
|
||||
float beforePrevAverageApproxBleu = 0;
|
||||
bool stop = false;
|
||||
bool weightsUpdated;
|
||||
size_t sumViolConstAfterOpt;
|
||||
float sumErrorAfterOpt;
|
||||
size_t sumStillViolatedConstraints;
|
||||
size_t sumStillViolatedConstraints_lastEpoch = 0;
|
||||
size_t sumConstraintChangeAbs;
|
||||
size_t sumConstraintChangeAbs_lastEpoch = 0;
|
||||
float *sendbuf, *recvbuf;
|
||||
sendbuf = (float *) malloc(sizeof(float));
|
||||
recvbuf = (float *) malloc(sizeof(float));
|
||||
@ -391,12 +392,9 @@ int main(int argc, char** argv) {
|
||||
for (size_t epoch = 0; epoch < epochs && !stop; ++epoch) {
|
||||
cerr << "\nRank " << rank << ", epoch " << epoch << endl;
|
||||
|
||||
// track whether there is any weight update this epoch
|
||||
weightsUpdated = false;
|
||||
|
||||
// sum of violated constraints and error after optimization
|
||||
sumViolConstAfterOpt = 0;
|
||||
sumErrorAfterOpt = 0;
|
||||
// sum of violated constraints
|
||||
sumStillViolatedConstraints = 0;
|
||||
sumConstraintChangeAbs = 0;
|
||||
|
||||
// sum of approx. sentence bleu scores per epoch
|
||||
summedApproxBleu = 0;
|
||||
@ -550,18 +548,19 @@ int main(int argc, char** argv) {
|
||||
// run optimiser on batch
|
||||
cerr << "\nRank " << rank << ", run optimiser:" << endl;
|
||||
ScoreComponentCollection oldWeights(mosesWeights);
|
||||
int updateStatus = optimiser->updateWeights(mosesWeights, featureValues,
|
||||
vector<int> update_status = optimiser->updateWeights(mosesWeights, featureValues,
|
||||
losses, bleuScores, oracleFeatureValues, oracleBleuScores, ref_ids,
|
||||
learning_rate, max_sentence_update, rank, updates_per_epoch, controlUpdates);
|
||||
|
||||
if (updateStatus == 1) {
|
||||
if (update_status[0] == 1) {
|
||||
cerr << "Rank " << rank << ", no update for batch" << endl;
|
||||
}
|
||||
else if (updateStatus == -1) {
|
||||
else if (update_status[0] == -1) {
|
||||
cerr << "Rank " << rank << ", update ignored" << endl;
|
||||
}
|
||||
else {
|
||||
weightsUpdated = true;
|
||||
sumConstraintChangeAbs += abs(update_status[1] - update_status[2]);
|
||||
sumStillViolatedConstraints += update_status[2];
|
||||
|
||||
if (updates_per_epoch == -1) {
|
||||
// pass new weights to decoder
|
||||
@ -773,7 +772,14 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
|
||||
if (stop_optimal) {
|
||||
//TODO
|
||||
if (epoch > 0) {
|
||||
if (sumConstraintChangeAbs_lastEpoch == sumConstraintChangeAbs && sumStillViolatedConstraints_lastEpoch == sumStillViolatedConstraints) {
|
||||
cerr << "Epoch " << epoch << ", sum of violated constraints and constraint changes has stayed the same: " << sumStillViolatedConstraints << ", " << sumConstraintChangeAbs << endl;
|
||||
}
|
||||
}
|
||||
|
||||
sumConstraintChangeAbs_lastEpoch = sumConstraintChangeAbs;
|
||||
sumStillViolatedConstraints_lastEpoch = sumStillViolatedConstraints;
|
||||
}
|
||||
|
||||
if (!stop) {
|
||||
|
@ -6,7 +6,7 @@ using namespace std;
|
||||
|
||||
namespace Mira {
|
||||
|
||||
int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
const vector<vector<ScoreComponentCollection> >& featureValues,
|
||||
const vector<vector<float> >& losses,
|
||||
const vector<std::vector<float> >& bleuScores, const vector<
|
||||
@ -203,7 +203,11 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
}
|
||||
} else {
|
||||
cerr << "Rank " << rank << ", check, no constraint violated for this batch" << endl;
|
||||
return 1;
|
||||
vector<int> status(3);
|
||||
status[0] = 1;
|
||||
status[1] = 0;
|
||||
status[2] = 0;
|
||||
return status;
|
||||
}
|
||||
|
||||
// sanity check: still violated constraints after optimisation?
|
||||
@ -236,7 +240,11 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
float distanceChange = oldDistanceFromOptimum - newDistanceFromOptimum;
|
||||
cerr << "Rank " << rank << ", check, there are still violated constraints, the distance change is: " << distanceChange << endl;
|
||||
if (controlUpdates && (violatedConstraintsBefore - violatedConstraintsAfter) < 0 && distanceChange < 0) {
|
||||
return -1;
|
||||
vector<int> statusPlus(3);
|
||||
statusPlus[0] = -1;
|
||||
statusPlus[1] = violatedConstraintsBefore;
|
||||
statusPlus[2] = violatedConstraintsAfter;
|
||||
return statusPlus;
|
||||
}
|
||||
}
|
||||
|
||||
@ -265,7 +273,11 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
cerr << "Rank " << rank << ", weights after update: " << currWeights << endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
vector<int> statusPlus(3);
|
||||
statusPlus[0] = 0;
|
||||
statusPlus[1] = violatedConstraintsBefore;
|
||||
statusPlus[2] = violatedConstraintsAfter;
|
||||
return statusPlus;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ namespace Mira {
|
||||
class Optimiser {
|
||||
public:
|
||||
Optimiser() {}
|
||||
virtual int updateWeights(Moses::ScoreComponentCollection& weights,
|
||||
virtual std::vector<int> updateWeights(Moses::ScoreComponentCollection& weights,
|
||||
const std::vector< std::vector<Moses::ScoreComponentCollection> >& featureValues,
|
||||
const std::vector< std::vector<float> >& losses,
|
||||
const std::vector<std::vector<float> >& bleuScores,
|
||||
@ -45,7 +45,7 @@ namespace Mira {
|
||||
|
||||
class Perceptron : public Optimiser {
|
||||
public:
|
||||
virtual int updateWeights(Moses::ScoreComponentCollection& weights,
|
||||
virtual std::vector<int> updateWeights(Moses::ScoreComponentCollection& weights,
|
||||
const std::vector< std::vector<Moses::ScoreComponentCollection> >& featureValues,
|
||||
const std::vector< std::vector<float> >& losses,
|
||||
const std::vector<std::vector<float> >& bleuScores,
|
||||
@ -80,7 +80,7 @@ namespace Mira {
|
||||
|
||||
~MiraOptimiser() {}
|
||||
|
||||
virtual int updateWeights(Moses::ScoreComponentCollection& weights,
|
||||
virtual std::vector<int> updateWeights(Moses::ScoreComponentCollection& weights,
|
||||
const std::vector< std::vector<Moses::ScoreComponentCollection> >& featureValues,
|
||||
const std::vector< std::vector<float> >& losses,
|
||||
const std::vector<std::vector<float> >& bleuScores,
|
||||
|
@ -24,7 +24,7 @@ using namespace std;
|
||||
|
||||
namespace Mira {
|
||||
|
||||
int Perceptron::updateWeights(ScoreComponentCollection& currWeights,
|
||||
vector<int> Perceptron::updateWeights(ScoreComponentCollection& currWeights,
|
||||
const vector< vector<ScoreComponentCollection> >& featureValues,
|
||||
const vector< vector<float> >& losses,
|
||||
const vector< vector<float> >& bleuScores,
|
||||
@ -46,7 +46,9 @@ int Perceptron::updateWeights(ScoreComponentCollection& currWeights,
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
vector<int> status(1);
|
||||
status[0] = 0;
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user