fix average moses weights

git-svn-id: http://svn.statmt.org/repository/mira@3856 cc96ff50-19ce-11e0-b349-13d7f0bd23df
ehasler 2011-03-29 17:08:07 +00:00 committed by Ondrej Bojar
parent 3b760065ec
commit 2dfe7163b8


@@ -335,11 +335,10 @@ int main(int argc, char** argv) {
//Main loop:
// print initial weights
-ScoreComponentCollection initialWeights = decoder->getWeights();
-cerr << "weights: " << initialWeights << endl;
-ScoreComponentCollection cumulativeWeights(initialWeights); // collect weights per epoch to produce an average
-size_t numberCumulativeWeights = 1;
-size_t numberCumulativeWeightsThisEpoch = 1;
+cerr << "Rank " << rank << ", initial weights: " << decoder->getWeights() << endl;
+ScoreComponentCollection cumulativeWeights; // collect weights per epoch to produce an average
+size_t numberCumulativeWeights = 0;
+size_t numberCumulativeWeightsThisEpoch = 0;
time_t now = time(0); // get current time
struct tm* tm = localtime(&now); // get struct filled out
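
This first hunk carries the actual fix: the running sum of weights now starts empty with both counters at zero, where it used to be seeded with the initial weights and counters of one, so every average was biased toward the starting point. A minimal sketch of the difference, with plain std::vector<float> standing in for ScoreComponentCollection and made-up values:

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> initial = {0.5f, 0.5f};
        std::vector<float> update  = {1.0f, 0.0f};

        // Old scheme: sum seeded with the initial weights, counter starts at 1.
        std::vector<float> cumOld(initial);
        size_t nOld = 1;
        // New scheme: sum starts empty, counter starts at 0.
        std::vector<float> cumNew(2, 0.0f);
        size_t nNew = 0;

        // One weight update arrives.
        for (size_t i = 0; i < 2; ++i) { cumOld[i] += update[i]; cumNew[i] += update[i]; }
        ++nOld; ++nNew;

        printf("old average: %.2f %.2f\n", cumOld[0] / nOld, cumOld[1] / nOld); // 0.75 0.25
        printf("new average: %.2f %.2f\n", cumNew[0] / nNew, cumNew[1] / nNew); // 1.00 0.00
        return 0;
    }

With the old seeding, the "average" after one update is really the mean of the initial weights and the update; the new scheme averages only the weights the updates actually produced.
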
@@ -363,14 +362,14 @@ int main(int argc, char** argv) {
float *sendbuf, *recvbuf;
sendbuf = (float *) malloc(sizeof(float));
recvbuf = (float *) malloc(sizeof(float));
+// Note: make sure that the variable mosesWeights always holds the current decoder weights
for (size_t epoch = 0; epoch < epochs && !stop; ++epoch) {
cerr << "\nEpoch " << epoch << endl;
-numberCumulativeWeightsThisEpoch = 1;
+numberCumulativeWeightsThisEpoch = 0;
summedApproxBleu = 0;
// Sum up weights over one epoch, final average uses weights from last epoch
if (!accumulateWeights) {
-cumulativeWeights.ZeroAll();
+cumulativeWeights = decoder->getWeights();
}
// number of weight dumps this epoch
@@ -475,8 +474,7 @@ int main(int argc, char** argv) {
// set weight for bleu feature to 0
const vector<const ScoreProducer*> featureFunctions =
-StaticData::Instance().GetTranslationSystem(
-TranslationSystem::DEFAULT).GetFeatureFunctions();
+StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT).GetFeatureFunctions();
mosesWeights.Assign(featureFunctions.back(), 0);
if (!hildreth && typeid(*optimiser) == typeid(MiraOptimiser)) {
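
The hunk itself only unwraps a line, but the surrounding code is worth a note: it zeroes the weight of the last registered feature function, which this code assumes is the BLEU feature, so that the BLEU score does not influence decoding. A toy version of that pattern (hypothetical names; the real weights are keyed by ScoreProducer pointers, not strings):

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
        // Feature functions in registration order; BLEU is assumed to be last.
        std::vector<std::string> featureFunctions = {"LM", "TM", "BleuScoreFeature"};
        std::map<std::string, float> weights =
            {{"LM", 0.5f}, {"TM", 0.3f}, {"BleuScoreFeature", 0.2f}};

        // Equivalent of mosesWeights.Assign(featureFunctions.back(), 0):
        weights[featureFunctions.back()] = 0.0f;

        for (const auto& kv : weights)
            std::cout << kv.first << " = " << kv.second << "\n";
        return 0;
    }
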
@@ -501,12 +499,11 @@ int main(int argc, char** argv) {
losses, bleuScores, oracleFeatureValues, oracleBleuScores, ref_ids,
learning_rate, max_sentence_update, rank, updates_per_epoch);
// set decoder weights and accumulate weights
if (controlUpdates && updateStatus < 0) {
// TODO: could try to repeat hildreth with more slack
cerr << "update ignored!" << endl;
} else if (updates_per_epoch == -1) {
-// if updates are not accumulated, apply new weights now
+// apply new weights now
if (normaliseWeights) {
mosesWeights.L1Normalise();
cerr << "\nRank " << rank << ", new weights (normalised): " << mosesWeights << endl;
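
When updates are applied immediately (updates_per_epoch == -1), the new weights are optionally L1-normalised before use. A self-contained sketch, assuming L1Normalise divides every component by the sum of absolute values:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Assumed behaviour of ScoreComponentCollection::L1Normalise.
    void l1Normalise(std::vector<float>& w) {
        float norm = 0.0f;
        for (size_t i = 0; i < w.size(); ++i) norm += std::fabs(w[i]);
        if (norm > 0.0f)
            for (size_t i = 0; i < w.size(); ++i) w[i] /= norm;
    }

    int main() {
        std::vector<float> w = {2.0f, -1.0f, 1.0f};
        l1Normalise(w);
        printf("%.2f %.2f %.2f\n", w[0], w[1], w[2]); // 0.50 -0.25 0.25
        return 0;
    }
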
@@ -524,15 +521,16 @@ int main(int argc, char** argv) {
} else {
averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
}
-decoder->setWeights(averageWeights);
+mosesWeights = averageWeights;
cerr << "Rank " << rank << ", average weights: " << averageWeights << endl;
}
-else {
-decoder->setWeights(mosesWeights);
-}
+// set new Moses weights (averaged or not)
+decoder->setWeights(mosesWeights);
// compute difference to old weights
-ScoreComponentCollection weightDifference(decoder->getWeights());
+ScoreComponentCollection weightDifference(mosesWeights);
weightDifference.MinusEquals(oldWeights);
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
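
The delta logging now starts from mosesWeights rather than re-reading decoder->getWeights(); since the decoder was just set from mosesWeights the two agree, but the new form stays correct even if the setWeights call moves. The computation itself, sketched with plain vectors and made-up values:

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> oldWeights   = {0.4f, 0.6f};
        std::vector<float> mosesWeights = {0.5f, 0.5f};

        std::vector<float> weightDifference(mosesWeights); // copy, then MinusEquals
        for (size_t i = 0; i < weightDifference.size(); ++i)
            weightDifference[i] -= oldWeights[i];

        printf("weight difference: %.2f %.2f\n",
               weightDifference[0], weightDifference[1]); // 0.10 -0.10
        return 0;
    }
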
@@ -582,7 +580,7 @@ int main(int argc, char** argv) {
// apply accumulated updates
if (makeUpdate && typeid(*optimiser) == typeid(MiraOptimiser)) {
-ScoreComponentCollection mosesWeights = decoder->getWeights();
+mosesWeights = decoder->getWeights();
ScoreComponentCollection accumulatedUpdates = ((MiraOptimiser*) optimiser)->getAccumulatedUpdates();
cerr << "\nRank " << rank << ", updates to apply during epoch " << epoch << ": " << accumulatedUpdates << endl;
mosesWeights.PlusEquals(accumulatedUpdates);
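
In the batched path the per-sentence updates collected over the epoch are applied in one step; the hunk also stops shadowing the outer mosesWeights with a local declaration, in line with the note above that mosesWeights should always track the decoder weights. A sketch of summing and applying accumulated updates (plain vectors, made-up values):

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> mosesWeights = {0.5f, 0.5f}; // decoder->getWeights()
        // Updates collected during the epoch (getAccumulatedUpdates()).
        std::vector<std::vector<float> > updates = {{0.1f, -0.1f}, {0.2f, 0.0f}};

        std::vector<float> accumulated(mosesWeights.size(), 0.0f);
        for (size_t u = 0; u < updates.size(); ++u)
            for (size_t i = 0; i < accumulated.size(); ++i)
                accumulated[i] += updates[u][i];

        // mosesWeights.PlusEquals(accumulatedUpdates);
        for (size_t i = 0; i < mosesWeights.size(); ++i)
            mosesWeights[i] += accumulated[i];

        printf("weights after cumulative update: %.2f %.2f\n",
               mosesWeights[0], mosesWeights[1]); // 0.80 0.40
        return 0;
    }
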
@@ -590,21 +588,34 @@ int main(int argc, char** argv) {
if (normaliseWeights) {
mosesWeights.L1Normalise();
-cerr << "Rank " << rank << ", new weights (normalised) during epoch " << epoch << ": " << mosesWeights << endl;
+cerr << "Rank " << rank << ", weights (normalised) after applying cumulative update: " << mosesWeights << endl;
} else {
-cerr << "Rank " << rank << ", new weights during epoch " << epoch << ": " << mosesWeights << endl;
+cerr << "Rank " << rank << ", weights after applying cumulative update: " << mosesWeights << endl;
}
-decoder->setWeights(mosesWeights);
cumulativeWeights.PlusEquals(mosesWeights);
++numberCumulativeWeights;
++numberCumulativeWeightsThisEpoch;
+if (useAverageWeightsForPruning) {
+ScoreComponentCollection averageWeights(cumulativeWeights);
+if (accumulateWeights) {
+averageWeights.DivideEquals(numberCumulativeWeights);
+} else {
+averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
+}
+mosesWeights = averageWeights;
+cerr << "Rank " << rank << ", average weights after applying cumulative update: " << averageWeights << endl;
+}
+decoder->setWeights(mosesWeights);
}
// mix weights?
if (shardPosition % (shard.size() / mixingFrequency) == 0) {
+ScoreComponentCollection mixedWeights;
#ifdef MPI_ENABLE
-ScoreComponentCollection mixedWeights;
cerr << "\nRank " << rank << ", before mixing: " << mosesWeights << endl;
// collect all weights in mixedWeights and divide by number of processes
@@ -626,9 +637,10 @@ int main(int argc, char** argv) {
// broadcast average weights from process 0
mpi::broadcast(world, mixedWeights, 0);
decoder->setWeights(mixedWeights);
+mosesWeights = mixedWeights;
#endif
#ifndef MPI_ENABLE
+mixedWeights = mosesWeights;
cerr << "\nRank " << rank << ", no mixing, weights: " << mosesWeights << endl;
#endif
} // end mixing
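
The mixing block averages the current weights across all MPI processes (reduce to rank 0, divide by the number of processes, broadcast back); the change here also copies the result into mosesWeights, and defines mixedWeights in non-MPI builds, keeping the invariant from the note at the top. A runnable sketch with Boost.MPI, which this code uses, with one float standing in for ScoreComponentCollection:

    #include <boost/mpi.hpp>
    #include <functional>
    #include <iostream>
    namespace mpi = boost::mpi;

    int main(int argc, char** argv) {
        mpi::environment env(argc, argv);
        mpi::communicator world;

        float mosesWeight = 0.1f * (world.rank() + 1); // made-up per-process weight
        float mixedWeight = 0.0f;

        // Sum all weights on process 0 and average them there ...
        if (world.rank() == 0) {
            mpi::reduce(world, mosesWeight, mixedWeight, std::plus<float>(), 0);
            mixedWeight /= world.size();
        } else {
            mpi::reduce(world, mosesWeight, std::plus<float>(), 0);
        }
        // ... then broadcast the average back to every process.
        mpi::broadcast(world, mixedWeight, 0);

        mosesWeight = mixedWeight; // keep the working copy in sync
        std::cout << "Rank " << world.rank()
                  << ", mixed weight: " << mixedWeight << std::endl;
        return 0;
    }

Build with mpic++ and link against boost_mpi and boost_serialization; run under mpirun to see each rank print the same mixed weight.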