mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2025-01-01 00:12:27 +03:00
sanity check for updates-per-epoch, stopping criterion 'sum of updates this epoch'
git-svn-id: http://svn.statmt.org/repository/mira@3865 cc96ff50-19ce-11e0-b349-13d7f0bd23df
This commit is contained in:
parent
4187e0f691
commit
58eb0d0052
109
mira/Main.cpp
109
mira/Main.cpp
@ -126,6 +126,7 @@ int main(int argc, char** argv) {
|
||||
bool stop_approx_dev_bleu;
|
||||
int updates_per_epoch;
|
||||
bool averageWeights;
|
||||
bool stop_optimal;
|
||||
po::options_description desc("Allowed options");
|
||||
desc.add_options()("accumulate-most-violated-constraints", po::value<bool>(&accumulateMostViolatedConstraints)->default_value(false),"Accumulate most violated constraint per example")
|
||||
("accumulate-weights", po::value<bool>(&accumulateWeights)->default_value(false), "Accumulate and average weights over all epochs")
|
||||
@ -173,6 +174,7 @@ int main(int argc, char** argv) {
|
||||
("stop-dev-bleu", po::value<bool>(&stop_dev_bleu)->default_value(false), "Stop when average Bleu (dev) decreases (or no more increases)")
|
||||
("stop-approx-dev-bleu", po::value<bool>(&stop_approx_dev_bleu)->default_value(false), "Stop when average approx. sentence Bleu (dev) decreases (or no more increases)")
|
||||
("stop-weights", po::value<bool>(&weightConvergence)->default_value(false), "Stop when weights converge")
|
||||
("stop-optimal", po::value<bool>(&stop_optimal)->default_value(false), "Stop when the results of optimization do not improve further")
|
||||
("updates-per-epoch", po::value<int>(&updates_per_epoch)->default_value(-1), "Accumulate updates and apply them to the weight vector the specified number of times per epoch")
|
||||
("use-scaled-reference", po::value<bool>(&useScaledReference)->default_value(true), "Use scaled reference length for comparing target and reference length of phrases")
|
||||
("verbosity,v", po::value<int>(&verbosity)->default_value(0), "Verbosity level")
|
||||
@ -312,6 +314,7 @@ int main(int argc, char** argv) {
|
||||
cerr << "print-feature-values: " << print_feature_values << endl;
|
||||
cerr << "stop-dev-bleu: " << stop_dev_bleu << endl;
|
||||
cerr << "stop-approx-dev-bleu: " << stop_approx_dev_bleu << endl;
|
||||
cerr << "stop-optimal: " << stop_optimal << endl;
|
||||
cerr << "stop-weights: " << weightConvergence << endl;
|
||||
cerr << "updates-per-epoch: " << updates_per_epoch << endl;
|
||||
cerr << "use-total-weights-for-pruning: " << averageWeights << endl;
|
||||
@ -337,8 +340,8 @@ int main(int argc, char** argv) {
|
||||
// print initial weights
|
||||
cerr << "Rank " << rank << ", initial weights: " << decoder->getWeights() << endl;
|
||||
ScoreComponentCollection cumulativeWeights; // collect weights per epoch to produce an average
|
||||
size_t numberCumulativeWeights = 0;
|
||||
size_t numberCumulativeWeightsThisEpoch = 0;
|
||||
size_t numberOfUpdates = 0;
|
||||
size_t numberOfUpdatesThisEpoch = 0;
|
||||
|
||||
time_t now = time(0); // get current time
|
||||
struct tm* tm = localtime(&now); // get struct filled out
|
||||
@ -359,15 +362,27 @@ int main(int argc, char** argv) {
|
||||
float prevAverageApproxBleu = 0;
|
||||
float beforePrevAverageApproxBleu = 0;
|
||||
bool stop = false;
|
||||
bool weightsUpdated;
|
||||
size_t sumViolConstAfterOpt;
|
||||
float sumErrorAfterOpt;
|
||||
float *sendbuf, *recvbuf;
|
||||
sendbuf = (float *) malloc(sizeof(float));
|
||||
recvbuf = (float *) malloc(sizeof(float));
|
||||
// Note: make sure that the variable mosesWeights always holds the current decoder weights
|
||||
for (size_t epoch = 0; epoch < epochs && !stop; ++epoch) {
|
||||
cerr << "\nRank " << rank << ", epoch " << epoch << endl;
|
||||
bool weightsUpdated = false;
|
||||
numberCumulativeWeightsThisEpoch = 0;
|
||||
|
||||
// track whether there is any weight update this epoch
|
||||
weightsUpdated = false;
|
||||
|
||||
// sum of violated constraints and error after optimization
|
||||
sumViolConstAfterOpt = 0;
|
||||
sumErrorAfterOpt = 0;
|
||||
|
||||
// sum of approx. sentence bleu scores per epoch
|
||||
summedApproxBleu = 0;
|
||||
|
||||
numberOfUpdatesThisEpoch = 0;
|
||||
// Sum up weights over one epoch, final average uses weights from last epoch
|
||||
if (!accumulateWeights) {
|
||||
cumulativeWeights.ZeroAll();
|
||||
@ -536,14 +551,14 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
|
||||
cumulativeWeights.PlusEquals(mosesWeights);
|
||||
++numberCumulativeWeights;
|
||||
++numberCumulativeWeightsThisEpoch;
|
||||
++numberOfUpdates;
|
||||
++numberOfUpdatesThisEpoch;
|
||||
if (averageWeights) {
|
||||
ScoreComponentCollection averageWeights(cumulativeWeights);
|
||||
if (accumulateWeights) {
|
||||
averageWeights.DivideEquals(numberCumulativeWeights);
|
||||
averageWeights.DivideEquals(numberOfUpdates);
|
||||
} else {
|
||||
averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
|
||||
averageWeights.DivideEquals(numberOfUpdatesThisEpoch);
|
||||
}
|
||||
|
||||
mosesWeights = averageWeights;
|
||||
@ -593,38 +608,43 @@ int main(int argc, char** argv) {
|
||||
mosesWeights = decoder->getWeights();
|
||||
ScoreComponentCollection accumulatedUpdates = ((MiraOptimiser*) optimiser)->getAccumulatedUpdates();
|
||||
cerr << "\nRank " << rank << ", updates to apply during epoch " << epoch << ": " << accumulatedUpdates << endl;
|
||||
mosesWeights.PlusEquals(accumulatedUpdates);
|
||||
((MiraOptimiser*) optimiser)->resetAccumulatedUpdates();
|
||||
if (accumulatedUpdates.GetWeightedScore() != 0) {
|
||||
mosesWeights.PlusEquals(accumulatedUpdates);
|
||||
((MiraOptimiser*) optimiser)->resetAccumulatedUpdates();
|
||||
|
||||
if (normaliseWeights) {
|
||||
mosesWeights.L1Normalise();
|
||||
}
|
||||
|
||||
cumulativeWeights.PlusEquals(mosesWeights);
|
||||
++numberCumulativeWeights;
|
||||
++numberCumulativeWeightsThisEpoch;
|
||||
|
||||
if (averageWeights) {
|
||||
ScoreComponentCollection averageWeights(cumulativeWeights);
|
||||
if (accumulateWeights) {
|
||||
averageWeights.DivideEquals(numberCumulativeWeights);
|
||||
} else {
|
||||
averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
|
||||
if (normaliseWeights) {
|
||||
mosesWeights.L1Normalise();
|
||||
}
|
||||
|
||||
mosesWeights = averageWeights;
|
||||
cerr << "Rank " << rank << ", set new average weights after applying cumulative update: " << mosesWeights << endl;
|
||||
cumulativeWeights.PlusEquals(mosesWeights);
|
||||
++numberOfUpdates;
|
||||
++numberOfUpdatesThisEpoch;
|
||||
|
||||
if (averageWeights) {
|
||||
ScoreComponentCollection averageWeights(cumulativeWeights);
|
||||
if (accumulateWeights) {
|
||||
averageWeights.DivideEquals(numberOfUpdates);
|
||||
} else {
|
||||
averageWeights.DivideEquals(numberOfUpdatesThisEpoch);
|
||||
}
|
||||
|
||||
mosesWeights = averageWeights;
|
||||
cerr << "Rank " << rank << ", set new average weights after applying cumulative update: " << mosesWeights << endl;
|
||||
}
|
||||
else {
|
||||
cerr << "Rank " << rank << ", set new weights after applying cumulative update: " << mosesWeights << endl;
|
||||
}
|
||||
|
||||
decoder->setWeights(mosesWeights);
|
||||
|
||||
// compute difference to old weights
|
||||
ScoreComponentCollection weightDifference(mosesWeights);
|
||||
weightDifference.MinusEquals(oldWeights);
|
||||
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
|
||||
}
|
||||
else {
|
||||
cerr << "Rank " << rank << ", set new weights after applying cumulative update: " << mosesWeights << endl;
|
||||
cerr << "Rank " << rank << ", cumulative update is empty.." << endl;
|
||||
}
|
||||
|
||||
decoder->setWeights(mosesWeights);
|
||||
|
||||
// compute difference to old weights
|
||||
ScoreComponentCollection weightDifference(mosesWeights);
|
||||
weightDifference.MinusEquals(oldWeights);
|
||||
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
|
||||
}
|
||||
|
||||
// mix weights?
|
||||
@ -663,9 +683,9 @@ int main(int argc, char** argv) {
|
||||
if (shardPosition % (shard.size() / weightDumpFrequency) == 0) {
|
||||
ScoreComponentCollection tmpAverageWeights(cumulativeWeights);
|
||||
if (accumulateWeights) {
|
||||
tmpAverageWeights.DivideEquals(numberCumulativeWeights);
|
||||
tmpAverageWeights.DivideEquals(numberOfUpdates);
|
||||
} else {
|
||||
tmpAverageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
|
||||
tmpAverageWeights.DivideEquals(numberOfUpdatesThisEpoch);
|
||||
}
|
||||
|
||||
#ifdef MPI_ENABLE
|
||||
@ -709,11 +729,22 @@ int main(int argc, char** argv) {
|
||||
}// end averaging and dumping total weights
|
||||
} // end of shard loop, end of this epoch
|
||||
|
||||
if (!weightsUpdated) {
|
||||
size_t sumUpdates;
|
||||
#ifdef MPI_ENABLE
|
||||
mpi::reduce(world, numberOfUpdatesThisEpoch, sumUpdates, SCCPlus(), 0);
|
||||
#endif
|
||||
#ifndef MPI_ENABLE
|
||||
sumUpdates = numberOfUpdatesThisEpoch;
|
||||
#endif
|
||||
if (sumUpdates == 0) {
|
||||
cerr << "\nNo weight updates during this epoch.. stopping." << endl;
|
||||
stop = true;
|
||||
}
|
||||
else {
|
||||
if (stop_optimal) {
|
||||
//TODO
|
||||
}
|
||||
|
||||
if (devBleu) {
|
||||
// calculate bleu score of dev set
|
||||
vector<float> bleuAndRatio = decoder->calculateBleuOfCorpus(allBestModelScore, all_ref_ids, epoch, rank);
|
||||
@ -769,7 +800,7 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
|
||||
// average approximate sentence bleu across processes
|
||||
sendbuf[0] = summedApproxBleu/numberCumulativeWeightsThisEpoch;
|
||||
sendbuf[0] = summedApproxBleu/numberOfUpdatesThisEpoch;
|
||||
recvbuf[0] = 0;
|
||||
MPI_Reduce(sendbuf, recvbuf, 1, MPI_FLOAT, MPI_SUM, 0, world);
|
||||
if (rank == 0) {
|
||||
@ -783,7 +814,7 @@ int main(int argc, char** argv) {
|
||||
#ifndef MPI_ENABLE
|
||||
averageBleu = bleu;
|
||||
cerr << "Average Bleu (dev) after epoch " << epoch << ": " << averageBleu << endl;
|
||||
averageApproxBleu = summedApproxBleu / numberCumulativeWeightsThisEpoch;
|
||||
averageApproxBleu = summedApproxBleu / numberOfUpdatesThisEpoch;
|
||||
cerr << "Average approx. sentence Bleu (dev) after epoch " << epoch << ": " << averageApproxBleu << endl;
|
||||
#endif
|
||||
if (rank == 0) {
|
||||
|
@ -202,13 +202,16 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
summedUpdate.PlusEquals(featureValueDiffs[k]);
|
||||
}
|
||||
} else {
|
||||
cerr << "No constraint violated for this batch" << endl;
|
||||
cerr << "Rank " << rank << ", no constraint violated for this batch" << endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// sanity check: still violated constraints after optimisation?
|
||||
ScoreComponentCollection newWeights(currWeights);
|
||||
newWeights.PlusEquals(summedUpdate);
|
||||
if (updates_per_epoch > 0) {
|
||||
newWeights.PlusEquals(m_accumulatedUpdates);
|
||||
}
|
||||
int violatedConstraintsAfter = 0;
|
||||
float newDistanceFromOptimum = 0;
|
||||
for (size_t i = 0; i < featureValues.size(); ++i) {
|
||||
|
Loading…
Reference in New Issue
Block a user