From 58eb0d0052c0c6e7f761af74132408896ec48924 Mon Sep 17 00:00:00 2001 From: ehasler Date: Fri, 8 Apr 2011 21:04:08 +0000 Subject: [PATCH] sanity check for updates-per-epoch, stopping criterion 'sum of updates this epoch' git-svn-id: http://svn.statmt.org/repository/mira@3865 cc96ff50-19ce-11e0-b349-13d7f0bd23df --- mira/Main.cpp | 109 ++++++++++++++++++++++++++--------------- mira/MiraOptimiser.cpp | 5 +- 2 files changed, 74 insertions(+), 40 deletions(-) diff --git a/mira/Main.cpp b/mira/Main.cpp index a422caf4f..8b2df9296 100644 --- a/mira/Main.cpp +++ b/mira/Main.cpp @@ -126,6 +126,7 @@ int main(int argc, char** argv) { bool stop_approx_dev_bleu; int updates_per_epoch; bool averageWeights; + bool stop_optimal; po::options_description desc("Allowed options"); desc.add_options()("accumulate-most-violated-constraints", po::value(&accumulateMostViolatedConstraints)->default_value(false),"Accumulate most violated constraint per example") ("accumulate-weights", po::value(&accumulateWeights)->default_value(false), "Accumulate and average weights over all epochs") @@ -173,6 +174,7 @@ int main(int argc, char** argv) { ("stop-dev-bleu", po::value(&stop_dev_bleu)->default_value(false), "Stop when average Bleu (dev) decreases (or no more increases)") ("stop-approx-dev-bleu", po::value(&stop_approx_dev_bleu)->default_value(false), "Stop when average approx. sentence Bleu (dev) decreases (or no more increases)") ("stop-weights", po::value(&weightConvergence)->default_value(false), "Stop when weights converge") + ("stop-optimal", po::value(&stop_optimal)->default_value(false), "Stop when the results of optimization do not improve further") ("updates-per-epoch", po::value(&updates_per_epoch)->default_value(-1), "Accumulate updates and apply them to the weight vector the specified number of times per epoch") ("use-scaled-reference", po::value(&useScaledReference)->default_value(true), "Use scaled reference length for comparing target and reference length of phrases") ("verbosity,v", po::value(&verbosity)->default_value(0), "Verbosity level") @@ -312,6 +314,7 @@ int main(int argc, char** argv) { cerr << "print-feature-values: " << print_feature_values << endl; cerr << "stop-dev-bleu: " << stop_dev_bleu << endl; cerr << "stop-approx-dev-bleu: " << stop_approx_dev_bleu << endl; + cerr << "stop-optimal: " << stop_optimal << endl; cerr << "stop-weights: " << weightConvergence << endl; cerr << "updates-per-epoch: " << updates_per_epoch << endl; cerr << "use-total-weights-for-pruning: " << averageWeights << endl; @@ -337,8 +340,8 @@ int main(int argc, char** argv) { // print initial weights cerr << "Rank " << rank << ", initial weights: " << decoder->getWeights() << endl; ScoreComponentCollection cumulativeWeights; // collect weights per epoch to produce an average - size_t numberCumulativeWeights = 0; - size_t numberCumulativeWeightsThisEpoch = 0; + size_t numberOfUpdates = 0; + size_t numberOfUpdatesThisEpoch = 0; time_t now = time(0); // get current time struct tm* tm = localtime(&now); // get struct filled out @@ -359,15 +362,27 @@ int main(int argc, char** argv) { float prevAverageApproxBleu = 0; float beforePrevAverageApproxBleu = 0; bool stop = false; + bool weightsUpdated; + size_t sumViolConstAfterOpt; + float sumErrorAfterOpt; float *sendbuf, *recvbuf; sendbuf = (float *) malloc(sizeof(float)); recvbuf = (float *) malloc(sizeof(float)); // Note: make sure that the variable mosesWeights always holds the current decoder weights for (size_t epoch = 0; epoch < epochs && !stop; ++epoch) { cerr << "\nRank " << rank << ", epoch " << epoch << endl; - bool weightsUpdated = false; - numberCumulativeWeightsThisEpoch = 0; + + // track whether there is any weight update this epoch + weightsUpdated = false; + + // sum of violated constraints and error after optimization + sumViolConstAfterOpt = 0; + sumErrorAfterOpt = 0; + + // sum of approx. sentence bleu scores per epoch summedApproxBleu = 0; + + numberOfUpdatesThisEpoch = 0; // Sum up weights over one epoch, final average uses weights from last epoch if (!accumulateWeights) { cumulativeWeights.ZeroAll(); @@ -536,14 +551,14 @@ int main(int argc, char** argv) { } cumulativeWeights.PlusEquals(mosesWeights); - ++numberCumulativeWeights; - ++numberCumulativeWeightsThisEpoch; + ++numberOfUpdates; + ++numberOfUpdatesThisEpoch; if (averageWeights) { ScoreComponentCollection averageWeights(cumulativeWeights); if (accumulateWeights) { - averageWeights.DivideEquals(numberCumulativeWeights); + averageWeights.DivideEquals(numberOfUpdates); } else { - averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch); + averageWeights.DivideEquals(numberOfUpdatesThisEpoch); } mosesWeights = averageWeights; @@ -593,38 +608,43 @@ int main(int argc, char** argv) { mosesWeights = decoder->getWeights(); ScoreComponentCollection accumulatedUpdates = ((MiraOptimiser*) optimiser)->getAccumulatedUpdates(); cerr << "\nRank " << rank << ", updates to apply during epoch " << epoch << ": " << accumulatedUpdates << endl; - mosesWeights.PlusEquals(accumulatedUpdates); - ((MiraOptimiser*) optimiser)->resetAccumulatedUpdates(); + if (accumulatedUpdates.GetWeightedScore() != 0) { + mosesWeights.PlusEquals(accumulatedUpdates); + ((MiraOptimiser*) optimiser)->resetAccumulatedUpdates(); - if (normaliseWeights) { - mosesWeights.L1Normalise(); - } - - cumulativeWeights.PlusEquals(mosesWeights); - ++numberCumulativeWeights; - ++numberCumulativeWeightsThisEpoch; - - if (averageWeights) { - ScoreComponentCollection averageWeights(cumulativeWeights); - if (accumulateWeights) { - averageWeights.DivideEquals(numberCumulativeWeights); - } else { - averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch); + if (normaliseWeights) { + mosesWeights.L1Normalise(); } - mosesWeights = averageWeights; - cerr << "Rank " << rank << ", set new average weights after applying cumulative update: " << mosesWeights << endl; + cumulativeWeights.PlusEquals(mosesWeights); + ++numberOfUpdates; + ++numberOfUpdatesThisEpoch; + + if (averageWeights) { + ScoreComponentCollection averageWeights(cumulativeWeights); + if (accumulateWeights) { + averageWeights.DivideEquals(numberOfUpdates); + } else { + averageWeights.DivideEquals(numberOfUpdatesThisEpoch); + } + + mosesWeights = averageWeights; + cerr << "Rank " << rank << ", set new average weights after applying cumulative update: " << mosesWeights << endl; + } + else { + cerr << "Rank " << rank << ", set new weights after applying cumulative update: " << mosesWeights << endl; + } + + decoder->setWeights(mosesWeights); + + // compute difference to old weights + ScoreComponentCollection weightDifference(mosesWeights); + weightDifference.MinusEquals(oldWeights); + cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl; } else { - cerr << "Rank " << rank << ", set new weights after applying cumulative update: " << mosesWeights << endl; + cerr << "Rank " << rank << ", cumulative update is empty.." << endl; } - - decoder->setWeights(mosesWeights); - - // compute difference to old weights - ScoreComponentCollection weightDifference(mosesWeights); - weightDifference.MinusEquals(oldWeights); - cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl; } // mix weights? @@ -663,9 +683,9 @@ int main(int argc, char** argv) { if (shardPosition % (shard.size() / weightDumpFrequency) == 0) { ScoreComponentCollection tmpAverageWeights(cumulativeWeights); if (accumulateWeights) { - tmpAverageWeights.DivideEquals(numberCumulativeWeights); + tmpAverageWeights.DivideEquals(numberOfUpdates); } else { - tmpAverageWeights.DivideEquals(numberCumulativeWeightsThisEpoch); + tmpAverageWeights.DivideEquals(numberOfUpdatesThisEpoch); } #ifdef MPI_ENABLE @@ -709,11 +729,22 @@ int main(int argc, char** argv) { }// end averaging and dumping total weights } // end of shard loop, end of this epoch - if (!weightsUpdated) { + size_t sumUpdates; +#ifdef MPI_ENABLE + mpi::reduce(world, numberOfUpdatesThisEpoch, sumUpdates, SCCPlus(), 0); +#endif +#ifndef MPI_ENABLE + sumUpdates = numberOfUpdatesThisEpoch; +#endif + if (sumUpdates == 0) { cerr << "\nNo weight updates during this epoch.. stopping." << endl; stop = true; } else { + if (stop_optimal) { + //TODO + } + if (devBleu) { // calculate bleu score of dev set vector bleuAndRatio = decoder->calculateBleuOfCorpus(allBestModelScore, all_ref_ids, epoch, rank); @@ -769,7 +800,7 @@ int main(int argc, char** argv) { } // average approximate sentence bleu across processes - sendbuf[0] = summedApproxBleu/numberCumulativeWeightsThisEpoch; + sendbuf[0] = summedApproxBleu/numberOfUpdatesThisEpoch; recvbuf[0] = 0; MPI_Reduce(sendbuf, recvbuf, 1, MPI_FLOAT, MPI_SUM, 0, world); if (rank == 0) { @@ -783,7 +814,7 @@ int main(int argc, char** argv) { #ifndef MPI_ENABLE averageBleu = bleu; cerr << "Average Bleu (dev) after epoch " << epoch << ": " << averageBleu << endl; - averageApproxBleu = summedApproxBleu / numberCumulativeWeightsThisEpoch; + averageApproxBleu = summedApproxBleu / numberOfUpdatesThisEpoch; cerr << "Average approx. sentence Bleu (dev) after epoch " << epoch << ": " << averageApproxBleu << endl; #endif if (rank == 0) { diff --git a/mira/MiraOptimiser.cpp b/mira/MiraOptimiser.cpp index cc7dfa4eb..2ae03c5d2 100644 --- a/mira/MiraOptimiser.cpp +++ b/mira/MiraOptimiser.cpp @@ -202,13 +202,16 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights, summedUpdate.PlusEquals(featureValueDiffs[k]); } } else { - cerr << "No constraint violated for this batch" << endl; + cerr << "Rank " << rank << ", no constraint violated for this batch" << endl; return 1; } // sanity check: still violated constraints after optimisation? ScoreComponentCollection newWeights(currWeights); newWeights.PlusEquals(summedUpdate); + if (updates_per_epoch > 0) { + newWeights.PlusEquals(m_accumulatedUpdates); + } int violatedConstraintsAfter = 0; float newDistanceFromOptimum = 0; for (size_t i = 0; i < featureValues.size(); ++i) {