mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-09-20 07:42:21 +03:00
fix average moses weights
git-svn-id: http://svn.statmt.org/repository/mira@3856 cc96ff50-19ce-11e0-b349-13d7f0bd23df
This commit is contained in:
parent
3b760065ec
commit
2dfe7163b8
@@ -335,11 +335,10 @@ int main(int argc, char** argv) {
|
||||
|
||||
//Main loop:
|
||||
// print initial weights
|
||||
ScoreComponentCollection initialWeights = decoder->getWeights();
|
||||
cerr << "weights: " << initialWeights << endl;
|
||||
ScoreComponentCollection cumulativeWeights(initialWeights); // collect weights per epoch to produce an average
|
||||
size_t numberCumulativeWeights = 1;
|
||||
size_t numberCumulativeWeightsThisEpoch = 1;
|
||||
cerr << "Rank " << rank << ", initial weights: " << decoder->getWeights() << endl;
|
||||
ScoreComponentCollection cumulativeWeights; // collect weights per epoch to produce an average
|
||||
size_t numberCumulativeWeights = 0;
|
||||
size_t numberCumulativeWeightsThisEpoch = 0;
|
||||
|
||||
time_t now = time(0); // get current time
|
||||
struct tm* tm = localtime(&now); // get struct filled out
|
||||
@@ -363,14 +362,14 @@ int main(int argc, char** argv) {
|
||||
float *sendbuf, *recvbuf;
|
||||
sendbuf = (float *) malloc(sizeof(float));
|
||||
recvbuf = (float *) malloc(sizeof(float));
|
||||
// Note: make sure that the variable mosesWeights always holds the current decoder weights
|
||||
for (size_t epoch = 0; epoch < epochs && !stop; ++epoch) {
|
||||
cerr << "\nEpoch " << epoch << endl;
|
||||
numberCumulativeWeightsThisEpoch = 1;
|
||||
numberCumulativeWeightsThisEpoch = 0;
|
||||
summedApproxBleu = 0;
|
||||
// Sum up weights over one epoch, final average uses weights from last epoch
|
||||
if (!accumulateWeights) {
|
||||
cumulativeWeights.ZeroAll();
|
||||
cumulativeWeights = decoder->getWeights();
|
||||
}
|
||||
|
||||
// number of weight dumps this epoch
|
||||
@@ -475,8 +474,7 @@ int main(int argc, char** argv) {
|
||||
|
||||
// set weight for bleu feature to 0
|
||||
const vector<const ScoreProducer*> featureFunctions =
|
||||
StaticData::Instance().GetTranslationSystem(
|
||||
TranslationSystem::DEFAULT).GetFeatureFunctions();
|
||||
StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT).GetFeatureFunctions();
|
||||
mosesWeights.Assign(featureFunctions.back(), 0);
|
||||
|
||||
if (!hildreth && typeid(*optimiser) == typeid(MiraOptimiser)) {
|
||||
@@ -501,12 +499,11 @@ int main(int argc, char** argv) {
|
||||
losses, bleuScores, oracleFeatureValues, oracleBleuScores, ref_ids,
|
||||
learning_rate, max_sentence_update, rank, updates_per_epoch);
|
||||
|
||||
// set decoder weights and accumulate weights
|
||||
if (controlUpdates && updateStatus < 0) {
|
||||
// TODO: could try to repeat hildreth with more slack
|
||||
cerr << "update ignored!" << endl;
|
||||
} else if (updates_per_epoch == -1) {
|
||||
// if updates are not accumulated, apply new weights now
|
||||
// apply new weights now
|
||||
if (normaliseWeights) {
|
||||
mosesWeights.L1Normalise();
|
||||
cerr << "\nRank " << rank << ", new weights (normalised): " << mosesWeights << endl;
|
||||
@@ -524,15 +521,16 @@ int main(int argc, char** argv) {
|
||||
} else {
|
||||
averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
|
||||
}
|
||||
decoder->setWeights(averageWeights);
|
||||
|
||||
mosesWeights = averageWeights;
|
||||
cerr << "Rank " << rank << ", average weights: " << averageWeights << endl;
|
||||
}
|
||||
else {
|
||||
decoder->setWeights(mosesWeights);
|
||||
}
|
||||
|
||||
// set new Moses weights (averaged or not)
|
||||
decoder->setWeights(mosesWeights);
|
||||
|
||||
// compute difference to old weights
|
||||
ScoreComponentCollection weightDifference(decoder->getWeights());
|
||||
ScoreComponentCollection weightDifference(mosesWeights);
|
||||
weightDifference.MinusEquals(oldWeights);
|
||||
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
|
||||
|
||||
@@ -582,7 +580,7 @@ int main(int argc, char** argv) {
|
||||
|
||||
// apply accumulated updates
|
||||
if (makeUpdate && typeid(*optimiser) == typeid(MiraOptimiser)) {
|
||||
ScoreComponentCollection mosesWeights = decoder->getWeights();
|
||||
mosesWeights = decoder->getWeights();
|
||||
ScoreComponentCollection accumulatedUpdates = ((MiraOptimiser*) optimiser)->getAccumulatedUpdates();
|
||||
cerr << "\nRank " << rank << ", updates to apply during epoch " << epoch << ": " << accumulatedUpdates << endl;
|
||||
mosesWeights.PlusEquals(accumulatedUpdates);
|
||||
@@ -590,21 +588,34 @@ int main(int argc, char** argv) {
|
||||
|
||||
if (normaliseWeights) {
|
||||
mosesWeights.L1Normalise();
|
||||
cerr << "Rank " << rank << ", new weights (normalised) during epoch " << epoch << ": " << mosesWeights << endl;
|
||||
cerr << "Rank " << rank << ", weights (normalised) after applying cumulative update: " << mosesWeights << endl;
|
||||
} else {
|
||||
cerr << "Rank " << rank << ", new weights during epoch " << epoch << ": " << mosesWeights << endl;
|
||||
cerr << "Rank " << rank << ", weights after applying cumulative update: " << mosesWeights << endl;
|
||||
}
|
||||
|
||||
decoder->setWeights(mosesWeights);
|
||||
cumulativeWeights.PlusEquals(mosesWeights);
|
||||
++numberCumulativeWeights;
|
||||
++numberCumulativeWeightsThisEpoch;
|
||||
|
||||
if (useAverageWeightsForPruning) {
|
||||
ScoreComponentCollection averageWeights(cumulativeWeights);
|
||||
if (accumulateWeights) {
|
||||
averageWeights.DivideEquals(numberCumulativeWeights);
|
||||
} else {
|
||||
averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
|
||||
}
|
||||
|
||||
mosesWeights = averageWeights;
|
||||
cerr << "Rank " << rank << ", average weights after applying cumulative update: " << averageWeights << endl;
|
||||
}
|
||||
|
||||
decoder->setWeights(mosesWeights);
|
||||
}
|
||||
|
||||
// mix weights?
|
||||
if (shardPosition % (shard.size() / mixingFrequency) == 0) {
|
||||
ScoreComponentCollection mixedWeights;
|
||||
#ifdef MPI_ENABLE
|
||||
ScoreComponentCollection mixedWeights;
|
||||
cerr << "\nRank " << rank << ", before mixing: " << mosesWeights << endl;
|
||||
|
||||
// collect all weights in mixedWeights and divide by number of processes
|
||||
@@ -626,9 +637,10 @@ int main(int argc, char** argv) {
|
||||
// broadcast average weights from process 0
|
||||
mpi::broadcast(world, mixedWeights, 0);
|
||||
decoder->setWeights(mixedWeights);
|
||||
mosesWeights = mixedWeights;
|
||||
#endif
|
||||
#ifndef MPI_ENABLE
|
||||
mixedWeights = mosesWeights;
|
||||
cerr << "\nRank " << rank << ", no mixing, weights: " << mosesWeights << endl;
|
||||
#endif
|
||||
} // end mixing
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user