Introduce stopping criterion based on changes in average total weights

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/branches/mira-mtm5@3760 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
evahasler 2010-12-06 15:28:51 +00:00
parent 87141f9f89
commit 47d91e8b93
2 changed files with 68 additions and 32 deletions

View File

@ -25,7 +25,6 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include <boost/program_options.hpp>
#ifdef MPI_ENABLE
#include "mpi.h"
#include <boost/mpi.hpp>
namespace mpi = boost::mpi;
#endif
@ -244,6 +243,10 @@ int main(int argc, char** argv) {
// the result of accumulating and averaging weights over one epoch and possibly several processes
ScoreComponentCollection averageTotalWeights;
ScoreComponentCollection averageTotalWeightsCurrent;
ScoreComponentCollection averageTotalWeightsPrevious;
ScoreComponentCollection averageTotalWeightsBeforePrevious;
// TODO: scaling of feature values for probabilistic features
vector< ScoreComponentCollection> list_of_delta_h; // collect delta_h and loss for all examples of an epoch
vector< float> list_of_losses;
@ -463,7 +466,6 @@ int main(int argc, char** argv) {
decoder->setWeights(averageWeights);
}
#endif
// dump weights?
if (shardPosition % (shard.size() / weightDumpFrequency) == 0) {
// compute average weights per process over iterations
@ -473,44 +475,32 @@ int main(int argc, char** argv) {
else
totalWeights.DivideEquals(iterationsThisEpoch);
#ifdef MPI_ENABLE
if (rank == 0) {
cerr << "Rank 0, cumulative weights: " << cumulativeWeights << endl;
cerr << "Rank 0, total weights: " << totalWeights << endl;
}
if (weightEpochDump + 1 == weightDumpFrequency){
// last weight dump in epoch
averageTotalWeightsBeforePrevious = averageTotalWeightsPrevious;
averageTotalWeightsPrevious = averageTotalWeightsCurrent;
}
#ifdef MPI_ENABLE
// average across processes
mpi::reduce(world, totalWeights, averageTotalWeights, SCCPlus(), 0);
#endif
#ifndef MPI_ENABLE
// or use weights from single process
averageTotalWeights = totalWeights;
#endif
// average across processes
#ifdef MPI_ENABLE
mpi::reduce(world, totalWeights, averageTotalWeights, SCCPlus(), 0);
if (rank == 0) {
if (rank == 0 && !weightDumpStem.empty()) {
// average and normalise weights
averageTotalWeights.DivideEquals(size);
averageTotalWeights.L1Normalise();
cerr << "Rank 0, average total weights: " << averageTotalWeights << endl;
// compute summed error after dumping weights
float summedError = 0.0;
for (size_t i = 0; i < list_of_delta_h.size(); ++i) {
summedError += (list_of_losses[i] - list_of_delta_h[i].InnerProduct(averageTotalWeights));
}
cerr << "Rank 0: summed error after dumping weights: " << summedError << " (" << list_of_delta_h.size() << " examples)" << endl;
}
#endif
#ifndef MPI_ENABLE
// or use weights from single process
averageTotalWeights = totalWeights;
// compute summed error after dumping weights
float summedError = 0.0;
for (size_t i = 0; i < list_of_delta_h.size(); ++i) {
summedError += (list_of_losses[i] - list_of_delta_h[i].InnerProduct(averageTotalWeights));
}
cerr << "summed error after dumping weights: " << summedError << " (" << list_of_delta_h.size() << " examples)" << endl;
#endif
if (!weightDumpStem.empty()) {
ostringstream filename;
filename << weightDumpStem << "_" << epoch;
if (weightDumpFrequency > 1) {
@ -519,6 +509,48 @@ int main(int argc, char** argv) {
VERBOSE(1, "Dumping weights for epoch " << epoch << " to " << filename.str() << endl);
averageTotalWeights.Save(filename.str());
if (weightEpochDump + 1 == weightDumpFrequency){
// last weight dump in epoch
// compute summed error
float summedError = 0;
for (size_t i = 0; i < list_of_delta_h.size(); ++i) {
summedError += (list_of_losses[i] - list_of_delta_h[i].InnerProduct(averageTotalWeights));
}
cerr << "Rank 0: summed error after dumping weights: " << summedError << " (" << list_of_delta_h.size() << " examples)" << endl;
// compare new average weights with previous weights
averageTotalWeightsCurrent = averageTotalWeights;
ScoreComponentCollection firstDiff(averageTotalWeightsCurrent);
firstDiff.MinusEquals(averageTotalWeightsPrevious);
ScoreComponentCollection secondDiff(averageTotalWeightsCurrent);
secondDiff.MinusEquals(averageTotalWeightsBeforePrevious);
// check whether stopping criterion has been reached
// (both difference vectors must have all weight changes smaller than 0.01)
bool reached = true;
FVector changes1 = firstDiff.GetScoresVector();
FVector changes2 = secondDiff.GetScoresVector();
FVector::const_iterator iterator1 = changes1.cbegin();
FVector::const_iterator iterator2 = changes2.cbegin();
while (iterator1 != changes1.cend()) {
if ((*iterator1).second >= 0.01 || (*iterator2).second >= 0.01) {
reached = false;
break;
}
++iterator1;
++iterator2;
}
if (reached) {
// stop MIRA
cerr << "\nStopping criterion has been reached.. stopping MIRA." << endl << endl;
goto end;
}
}
++weightEpochDump;
}
}
@ -528,14 +560,14 @@ int main(int argc, char** argv) {
list_of_losses.clear();
}
end:
#ifdef MPI_ENABLE
MPI_Finalize();
#endif
if (rank == 0) {
cerr << "Average total weights: " << averageTotalWeights << endl;
}
#endif
now = time(0); // get current time
tm = localtime(&now); // get struct filled out

View File

@ -79,6 +79,10 @@ public:
return m_scores.load(filename);
}
//! Return the scores held in m_scores, by value.
//! The caller receives its own FVector; modifying it does not affect this
//! collection. Const-qualified because the accessor only reads m_scores,
//! which lets it be called on const collections as well.
FVector GetScoresVector() const
{
  return m_scores;
}
//! Set all values to 0.0
void ZeroAll()