mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-09-11 19:27:11 +03:00
Introduce stopping criterion based on changes in average total weights
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/branches/mira-mtm5@3760 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
parent
87141f9f89
commit
47d91e8b93
@ -25,7 +25,6 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
#include <boost/program_options.hpp>
|
||||
#ifdef MPI_ENABLE
|
||||
#include "mpi.h"
|
||||
#include <boost/mpi.hpp>
|
||||
namespace mpi = boost::mpi;
|
||||
#endif
|
||||
@ -244,6 +243,10 @@ int main(int argc, char** argv) {
|
||||
// the result of accumulating and averaging weights over one epoch and possibly several processes
|
||||
ScoreComponentCollection averageTotalWeights;
|
||||
|
||||
ScoreComponentCollection averageTotalWeightsCurrent;
|
||||
ScoreComponentCollection averageTotalWeightsPrevious;
|
||||
ScoreComponentCollection averageTotalWeightsBeforePrevious;
|
||||
|
||||
// TODO: scaling of feature values for probabilistic features
|
||||
vector< ScoreComponentCollection> list_of_delta_h; // collect delta_h and loss for all examples of an epoch
|
||||
vector< float> list_of_losses;
|
||||
@ -463,7 +466,6 @@ int main(int argc, char** argv) {
|
||||
decoder->setWeights(averageWeights);
|
||||
}
|
||||
#endif
|
||||
|
||||
// dump weights?
|
||||
if (shardPosition % (shard.size() / weightDumpFrequency) == 0) {
|
||||
// compute average weights per process over iterations
|
||||
@ -473,44 +475,32 @@ int main(int argc, char** argv) {
|
||||
else
|
||||
totalWeights.DivideEquals(iterationsThisEpoch);
|
||||
|
||||
#ifdef MPI_ENABLE
|
||||
if (rank == 0) {
|
||||
cerr << "Rank 0, cumulative weights: " << cumulativeWeights << endl;
|
||||
cerr << "Rank 0, total weights: " << totalWeights << endl;
|
||||
}
|
||||
|
||||
if (weightEpochDump + 1 == weightDumpFrequency){
|
||||
// last weight dump in epoch
|
||||
averageTotalWeightsBeforePrevious = averageTotalWeightsPrevious;
|
||||
averageTotalWeightsPrevious = averageTotalWeightsCurrent;
|
||||
}
|
||||
|
||||
#ifdef MPI_ENABLE
|
||||
// average across processes
|
||||
mpi::reduce(world, totalWeights, averageTotalWeights, SCCPlus(), 0);
|
||||
#endif
|
||||
#ifndef MPI_ENABLE
|
||||
// or use weights from single process
|
||||
averageTotalWeights = totalWeights;
|
||||
#endif
|
||||
|
||||
// average across processes
|
||||
#ifdef MPI_ENABLE
|
||||
mpi::reduce(world, totalWeights, averageTotalWeights, SCCPlus(), 0);
|
||||
if (rank == 0) {
|
||||
if (rank == 0 && !weightDumpStem.empty()) {
|
||||
// average and normalise weights
|
||||
averageTotalWeights.DivideEquals(size);
|
||||
averageTotalWeights.L1Normalise();
|
||||
cerr << "Rank 0, average total weights: " << averageTotalWeights << endl;
|
||||
|
||||
// compute summed error after dumping weights
|
||||
float summedError = 0.0;
|
||||
for (size_t i = 0; i < list_of_delta_h.size(); ++i) {
|
||||
summedError += (list_of_losses[i] - list_of_delta_h[i].InnerProduct(averageTotalWeights));
|
||||
}
|
||||
|
||||
cerr << "Rank 0: summed error after dumping weights: " << summedError << " (" << list_of_delta_h.size() << " examples)" << endl;
|
||||
}
|
||||
#endif
|
||||
#ifndef MPI_ENABLE
|
||||
// or use weights from single process
|
||||
averageTotalWeights = totalWeights;
|
||||
|
||||
// compute summed error after dumping weights
|
||||
float summedError = 0.0;
|
||||
for (size_t i = 0; i < list_of_delta_h.size(); ++i) {
|
||||
summedError += (list_of_losses[i] - list_of_delta_h[i].InnerProduct(averageTotalWeights));
|
||||
}
|
||||
|
||||
cerr << "summed error after dumping weights: " << summedError << " (" << list_of_delta_h.size() << " examples)" << endl;
|
||||
#endif
|
||||
if (!weightDumpStem.empty()) {
|
||||
ostringstream filename;
|
||||
filename << weightDumpStem << "_" << epoch;
|
||||
if (weightDumpFrequency > 1) {
|
||||
@ -519,6 +509,48 @@ int main(int argc, char** argv) {
|
||||
|
||||
VERBOSE(1, "Dumping weights for epoch " << epoch << " to " << filename.str() << endl);
|
||||
averageTotalWeights.Save(filename.str());
|
||||
|
||||
if (weightEpochDump + 1 == weightDumpFrequency){
|
||||
// last weight dump in epoch
|
||||
|
||||
// compute summed error
|
||||
float summedError = 0;
|
||||
for (size_t i = 0; i < list_of_delta_h.size(); ++i) {
|
||||
summedError += (list_of_losses[i] - list_of_delta_h[i].InnerProduct(averageTotalWeights));
|
||||
}
|
||||
cerr << "Rank 0: summed error after dumping weights: " << summedError << " (" << list_of_delta_h.size() << " examples)" << endl;
|
||||
|
||||
// compare new average weights with previous weights
|
||||
averageTotalWeightsCurrent = averageTotalWeights;
|
||||
ScoreComponentCollection firstDiff(averageTotalWeightsCurrent);
|
||||
firstDiff.MinusEquals(averageTotalWeightsPrevious);
|
||||
ScoreComponentCollection secondDiff(averageTotalWeightsCurrent);
|
||||
secondDiff.MinusEquals(averageTotalWeightsBeforePrevious);
|
||||
|
||||
// check whether stopping criterion has been reached
|
||||
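// (for both difference vectors, every signed weight change must be below 0.01;
// NOTE: the check below does not take absolute values, so negative changes of any size also pass)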
|
||||
bool reached = true;
|
||||
FVector changes1 = firstDiff.GetScoresVector();
|
||||
FVector changes2 = secondDiff.GetScoresVector();
|
||||
FVector::const_iterator iterator1 = changes1.cbegin();
|
||||
FVector::const_iterator iterator2 = changes2.cbegin();
|
||||
while (iterator1 != changes1.cend()) {
|
||||
if ((*iterator1).second >= 0.01 || (*iterator2).second >= 0.01) {
|
||||
reached = false;
|
||||
break;
|
||||
}
|
||||
|
||||
++iterator1;
|
||||
++iterator2;
|
||||
}
|
||||
|
||||
if (reached) {
|
||||
// stop MIRA
|
||||
cerr << "\nStopping criterion has been reached.. stopping MIRA." << endl << endl;
|
||||
goto end;
|
||||
}
|
||||
}
|
||||
|
||||
++weightEpochDump;
|
||||
}
|
||||
}
|
||||
@ -528,14 +560,14 @@ int main(int argc, char** argv) {
|
||||
list_of_losses.clear();
|
||||
}
|
||||
|
||||
end:
|
||||
#ifdef MPI_ENABLE
|
||||
MPI_Finalize();
|
||||
#endif
|
||||
|
||||
if (rank == 0) {
|
||||
cerr << "Average total weights: " << averageTotalWeights << endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
now = time(0); // get current time
|
||||
tm = localtime(&now); // get struct filled out
|
||||
|
@ -79,6 +79,10 @@ public:
|
||||
return m_scores.load(filename);
|
||||
}
|
||||
|
||||
//! Read-only access to the underlying score vector (m_scores).
//! Declared const so it can be called on const collections, and returns a
//! const reference so callers that only inspect the scores do not pay for a
//! copy. Assigning the result to an FVector still yields an independent copy.
const FVector& GetScoresVector() const
{
  return m_scores;
}
|
||||
|
||||
//! Set all values to 0.0
|
||||
void ZeroAll()
|
||||
|
Loading…
Reference in New Issue
Block a user