Add sanity check for updates-per-epoch; add stopping criterion 'sum of updates this epoch'

git-svn-id: http://svn.statmt.org/repository/mira@3865 cc96ff50-19ce-11e0-b349-13d7f0bd23df
This commit is contained in:
ehasler 2011-04-08 21:04:08 +00:00 committed by Ondrej Bojar
parent 4187e0f691
commit 58eb0d0052
2 changed files with 74 additions and 40 deletions

View File

@ -126,6 +126,7 @@ int main(int argc, char** argv) {
bool stop_approx_dev_bleu;
int updates_per_epoch;
bool averageWeights;
bool stop_optimal;
po::options_description desc("Allowed options");
desc.add_options()("accumulate-most-violated-constraints", po::value<bool>(&accumulateMostViolatedConstraints)->default_value(false),"Accumulate most violated constraint per example")
("accumulate-weights", po::value<bool>(&accumulateWeights)->default_value(false), "Accumulate and average weights over all epochs")
@ -173,6 +174,7 @@ int main(int argc, char** argv) {
("stop-dev-bleu", po::value<bool>(&stop_dev_bleu)->default_value(false), "Stop when average Bleu (dev) decreases (or no more increases)")
("stop-approx-dev-bleu", po::value<bool>(&stop_approx_dev_bleu)->default_value(false), "Stop when average approx. sentence Bleu (dev) decreases (or no more increases)")
("stop-weights", po::value<bool>(&weightConvergence)->default_value(false), "Stop when weights converge")
("stop-optimal", po::value<bool>(&stop_optimal)->default_value(false), "Stop when the results of optimization do not improve further")
("updates-per-epoch", po::value<int>(&updates_per_epoch)->default_value(-1), "Accumulate updates and apply them to the weight vector the specified number of times per epoch")
("use-scaled-reference", po::value<bool>(&useScaledReference)->default_value(true), "Use scaled reference length for comparing target and reference length of phrases")
("verbosity,v", po::value<int>(&verbosity)->default_value(0), "Verbosity level")
@ -312,6 +314,7 @@ int main(int argc, char** argv) {
cerr << "print-feature-values: " << print_feature_values << endl;
cerr << "stop-dev-bleu: " << stop_dev_bleu << endl;
cerr << "stop-approx-dev-bleu: " << stop_approx_dev_bleu << endl;
cerr << "stop-optimal: " << stop_optimal << endl;
cerr << "stop-weights: " << weightConvergence << endl;
cerr << "updates-per-epoch: " << updates_per_epoch << endl;
cerr << "use-total-weights-for-pruning: " << averageWeights << endl;
@ -337,8 +340,8 @@ int main(int argc, char** argv) {
// print initial weights
cerr << "Rank " << rank << ", initial weights: " << decoder->getWeights() << endl;
ScoreComponentCollection cumulativeWeights; // collect weights per epoch to produce an average
size_t numberCumulativeWeights = 0;
size_t numberCumulativeWeightsThisEpoch = 0;
size_t numberOfUpdates = 0;
size_t numberOfUpdatesThisEpoch = 0;
time_t now = time(0); // get current time
struct tm* tm = localtime(&now); // get struct filled out
@ -359,15 +362,27 @@ int main(int argc, char** argv) {
float prevAverageApproxBleu = 0;
float beforePrevAverageApproxBleu = 0;
bool stop = false;
bool weightsUpdated;
size_t sumViolConstAfterOpt;
float sumErrorAfterOpt;
float *sendbuf, *recvbuf;
sendbuf = (float *) malloc(sizeof(float));
recvbuf = (float *) malloc(sizeof(float));
// Note: make sure that the variable mosesWeights always holds the current decoder weights
for (size_t epoch = 0; epoch < epochs && !stop; ++epoch) {
cerr << "\nRank " << rank << ", epoch " << epoch << endl;
bool weightsUpdated = false;
numberCumulativeWeightsThisEpoch = 0;
// track whether there is any weight update this epoch
weightsUpdated = false;
// sum of violated constraints and error after optimization
sumViolConstAfterOpt = 0;
sumErrorAfterOpt = 0;
// sum of approx. sentence bleu scores per epoch
summedApproxBleu = 0;
numberOfUpdatesThisEpoch = 0;
// Sum up weights over one epoch, final average uses weights from last epoch
if (!accumulateWeights) {
cumulativeWeights.ZeroAll();
@ -536,14 +551,14 @@ int main(int argc, char** argv) {
}
cumulativeWeights.PlusEquals(mosesWeights);
++numberCumulativeWeights;
++numberCumulativeWeightsThisEpoch;
++numberOfUpdates;
++numberOfUpdatesThisEpoch;
if (averageWeights) {
ScoreComponentCollection averageWeights(cumulativeWeights);
if (accumulateWeights) {
averageWeights.DivideEquals(numberCumulativeWeights);
averageWeights.DivideEquals(numberOfUpdates);
} else {
averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
averageWeights.DivideEquals(numberOfUpdatesThisEpoch);
}
mosesWeights = averageWeights;
@ -593,38 +608,43 @@ int main(int argc, char** argv) {
mosesWeights = decoder->getWeights();
ScoreComponentCollection accumulatedUpdates = ((MiraOptimiser*) optimiser)->getAccumulatedUpdates();
cerr << "\nRank " << rank << ", updates to apply during epoch " << epoch << ": " << accumulatedUpdates << endl;
mosesWeights.PlusEquals(accumulatedUpdates);
((MiraOptimiser*) optimiser)->resetAccumulatedUpdates();
if (accumulatedUpdates.GetWeightedScore() != 0) {
mosesWeights.PlusEquals(accumulatedUpdates);
((MiraOptimiser*) optimiser)->resetAccumulatedUpdates();
if (normaliseWeights) {
mosesWeights.L1Normalise();
}
cumulativeWeights.PlusEquals(mosesWeights);
++numberCumulativeWeights;
++numberCumulativeWeightsThisEpoch;
if (averageWeights) {
ScoreComponentCollection averageWeights(cumulativeWeights);
if (accumulateWeights) {
averageWeights.DivideEquals(numberCumulativeWeights);
} else {
averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
if (normaliseWeights) {
mosesWeights.L1Normalise();
}
mosesWeights = averageWeights;
cerr << "Rank " << rank << ", set new average weights after applying cumulative update: " << mosesWeights << endl;
cumulativeWeights.PlusEquals(mosesWeights);
++numberOfUpdates;
++numberOfUpdatesThisEpoch;
if (averageWeights) {
ScoreComponentCollection averageWeights(cumulativeWeights);
if (accumulateWeights) {
averageWeights.DivideEquals(numberOfUpdates);
} else {
averageWeights.DivideEquals(numberOfUpdatesThisEpoch);
}
mosesWeights = averageWeights;
cerr << "Rank " << rank << ", set new average weights after applying cumulative update: " << mosesWeights << endl;
}
else {
cerr << "Rank " << rank << ", set new weights after applying cumulative update: " << mosesWeights << endl;
}
decoder->setWeights(mosesWeights);
// compute difference to old weights
ScoreComponentCollection weightDifference(mosesWeights);
weightDifference.MinusEquals(oldWeights);
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
}
else {
cerr << "Rank " << rank << ", set new weights after applying cumulative update: " << mosesWeights << endl;
cerr << "Rank " << rank << ", cumulative update is empty.." << endl;
}
decoder->setWeights(mosesWeights);
// compute difference to old weights
ScoreComponentCollection weightDifference(mosesWeights);
weightDifference.MinusEquals(oldWeights);
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
}
// mix weights?
@ -663,9 +683,9 @@ int main(int argc, char** argv) {
if (shardPosition % (shard.size() / weightDumpFrequency) == 0) {
ScoreComponentCollection tmpAverageWeights(cumulativeWeights);
if (accumulateWeights) {
tmpAverageWeights.DivideEquals(numberCumulativeWeights);
tmpAverageWeights.DivideEquals(numberOfUpdates);
} else {
tmpAverageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
tmpAverageWeights.DivideEquals(numberOfUpdatesThisEpoch);
}
#ifdef MPI_ENABLE
@ -709,11 +729,22 @@ int main(int argc, char** argv) {
}// end averaging and dumping total weights
} // end of shard loop, end of this epoch
if (!weightsUpdated) {
size_t sumUpdates;
#ifdef MPI_ENABLE
mpi::reduce(world, numberOfUpdatesThisEpoch, sumUpdates, SCCPlus(), 0);
#endif
#ifndef MPI_ENABLE
sumUpdates = numberOfUpdatesThisEpoch;
#endif
if (sumUpdates == 0) {
cerr << "\nNo weight updates during this epoch.. stopping." << endl;
stop = true;
}
else {
if (stop_optimal) {
//TODO
}
if (devBleu) {
// calculate bleu score of dev set
vector<float> bleuAndRatio = decoder->calculateBleuOfCorpus(allBestModelScore, all_ref_ids, epoch, rank);
@ -769,7 +800,7 @@ int main(int argc, char** argv) {
}
// average approximate sentence bleu across processes
sendbuf[0] = summedApproxBleu/numberCumulativeWeightsThisEpoch;
sendbuf[0] = summedApproxBleu/numberOfUpdatesThisEpoch;
recvbuf[0] = 0;
MPI_Reduce(sendbuf, recvbuf, 1, MPI_FLOAT, MPI_SUM, 0, world);
if (rank == 0) {
@ -783,7 +814,7 @@ int main(int argc, char** argv) {
#ifndef MPI_ENABLE
averageBleu = bleu;
cerr << "Average Bleu (dev) after epoch " << epoch << ": " << averageBleu << endl;
averageApproxBleu = summedApproxBleu / numberCumulativeWeightsThisEpoch;
averageApproxBleu = summedApproxBleu / numberOfUpdatesThisEpoch;
cerr << "Average approx. sentence Bleu (dev) after epoch " << epoch << ": " << averageApproxBleu << endl;
#endif
if (rank == 0) {

View File

@ -202,13 +202,16 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
summedUpdate.PlusEquals(featureValueDiffs[k]);
}
} else {
cerr << "No constraint violated for this batch" << endl;
cerr << "Rank " << rank << ", no constraint violated for this batch" << endl;
return 1;
}
// sanity check: still violated constraints after optimisation?
ScoreComponentCollection newWeights(currWeights);
newWeights.PlusEquals(summedUpdate);
if (updates_per_epoch > 0) {
newWeights.PlusEquals(m_accumulatedUpdates);
}
int violatedConstraintsAfter = 0;
float newDistanceFromOptimum = 0;
for (size_t i = 0; i < featureValues.size(); ++i) {