mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-09-20 07:42:21 +03:00
introduce option to use average weights for pruning
git-svn-id: http://svn.statmt.org/repository/mira@3852 cc96ff50-19ce-11e0-b349-13d7f0bd23df
This commit is contained in:
parent
f6483df41c
commit
269f1018c3
185
mira/Main.cpp
185
mira/Main.cpp
@ -124,8 +124,8 @@ int main(int argc, char** argv) {
|
||||
bool print_feature_values;
|
||||
bool stop_dev_bleu;
|
||||
bool stop_approx_dev_bleu;
|
||||
bool update_after_epoch;
|
||||
size_t update_after;
|
||||
int updates_per_epoch;
|
||||
bool useAverageWeightsForPruning;
|
||||
po::options_description desc("Allowed options");
|
||||
desc.add_options()("accumulate-most-violated-constraints", po::value<bool>(&accumulateMostViolatedConstraints)->default_value(false),"Accumulate most violated constraint per example")
|
||||
("accumulate-weights", po::value<bool>(&accumulateWeights)->default_value(true), "Accumulate and average weights over all epochs")
|
||||
@ -140,6 +140,7 @@ int main(int argc, char** argv) {
|
||||
("decr-sentence-update", po::value<float>(&decrease_sentence_update)->default_value(0), "Decrease maximum weight update by the given value after every epoch")
|
||||
("dev-bleu", po::value<bool>(&devBleu)->default_value(true), "Compute BLEU score of oracle translations of the whole tuning set")
|
||||
("distinct-nbest", po::value<bool>(&distinctNbest)->default_value(false), "Use nbest list with distinct translations in inference step")
|
||||
("weight-dump-frequency", po::value<size_t>(&weightDumpFrequency)->default_value(1), "How often per epoch to dump weights, when using mpi")
|
||||
("epochs,e", po::value<size_t>(&epochs)->default_value(5), "Number of epochs")
|
||||
("help", po::value(&help)->zero_tokens()->default_value(false), "Print this help message and exit")
|
||||
("hildreth", po::value<bool>(&hildreth)->default_value(true), "Use Hildreth's optimisation algorithm")
|
||||
@ -171,8 +172,8 @@ int main(int argc, char** argv) {
|
||||
("stop-dev-bleu", po::value<bool>(&stop_dev_bleu)->default_value(false), "Stop when average Bleu (dev) decreases")
|
||||
("stop-approx-dev-bleu", po::value<bool>(&stop_approx_dev_bleu)->default_value(false), "Stop when average approx. sentence Bleu (dev) decreases")
|
||||
("stop-weights", po::value<bool>(&weightConvergence)->default_value(false), "Stop when weights converge")
|
||||
("update-after-epoch", po::value<bool>(&update_after_epoch)->default_value(false), "Accumulate updates and apply them to the weight vector at the end of an epoch")
|
||||
("update-after", po::value<size_t>(&update_after)->default_value(1), "Accumulate updates of given number of sentences and apply them to the weight vector afterwards")
|
||||
("updates-per-epoch", po::value<int>(&updates_per_epoch)->default_value(-1), "Accumulate updates and apply them to the weight vector the specified number of times per epoch")
|
||||
("use-average-weights-for-pruning", po::value<bool>(&useAverageWeightsForPruning)->default_value(true), "Use total weights (cumulative/weights changes) for pruning instead of current weights")
|
||||
("use-scaled-reference", po::value<bool>(&useScaledReference)->default_value(true), "Use scaled reference length for comparing target and reference length of phrases")
|
||||
("verbosity,v", po::value<int>(&verbosity)->default_value(0), "Verbosity level")
|
||||
("weighted-loss-function", po::value<bool>(&weightedLossFunction)->default_value(false), "Weight the loss of a hypothesis by its Bleu score")
|
||||
@ -275,6 +276,7 @@ int main(int argc, char** argv) {
|
||||
Optimiser* optimiser = NULL;
|
||||
cerr << "adapt-BP-factor: " << adapt_BPfactor << endl;
|
||||
cerr << "mix-frequency: " << mixingFrequency << endl;
|
||||
cerr << "weight-dump-frequency: " << weightDumpFrequency << endl;
|
||||
cerr << "weight-dump-stem: " << weightDumpStem << endl;
|
||||
cerr << "shuffle: " << shuffle << endl;
|
||||
cerr << "hildreth: " << hildreth << endl;
|
||||
@ -311,12 +313,8 @@ int main(int argc, char** argv) {
|
||||
cerr << "stop-dev-bleu: " << stop_dev_bleu << endl;
|
||||
cerr << "stop-approx-dev-bleu: " << stop_approx_dev_bleu << endl;
|
||||
cerr << "stop-weights: " << weightConvergence << endl;
|
||||
cerr << "update-after-epoch: " << update_after_epoch << endl;
|
||||
|
||||
if (update_after > 1) {
|
||||
// make sure the remaining changes of one epoch are applied
|
||||
update_after_epoch = true;
|
||||
}
|
||||
cerr << "updates-per-epoch: " << updates_per_epoch << endl;
|
||||
cerr << "use-total-weights-for-pruning: " << useAverageWeightsForPruning << endl;
|
||||
|
||||
if (learner == "mira") {
|
||||
cerr << "Optimising using Mira" << endl;
|
||||
@ -336,9 +334,12 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
|
||||
//Main loop:
|
||||
ScoreComponentCollection cumulativeWeights; // collect weights per epoch to produce an average
|
||||
size_t weightChanges = 0;
|
||||
size_t weightChangesThisEpoch = 0;
|
||||
// print initial weights
|
||||
ScoreComponentCollection initialWeights = decoder->getWeights();
|
||||
cerr << "weights: " << initialWeights << endl;
|
||||
ScoreComponentCollection cumulativeWeights(initialWeights); // collect weights per epoch to produce an average
|
||||
size_t numberCumulativeWeights = 1;
|
||||
size_t numberCumulativeWeightsThisEpoch = 1;
|
||||
|
||||
time_t now = time(0); // get current time
|
||||
struct tm* tm = localtime(&now); // get struct filled out
|
||||
@ -346,14 +347,9 @@ int main(int argc, char** argv) {
|
||||
<< tm->tm_year + 1900 << ", " << tm->tm_hour << ":" << tm->tm_min << ":"
|
||||
<< tm->tm_sec << endl;
|
||||
|
||||
ScoreComponentCollection averageWeights;
|
||||
ScoreComponentCollection averageTotalWeights;
|
||||
ScoreComponentCollection averageTotalWeightsEnd;
|
||||
ScoreComponentCollection averageTotalWeightsPrevious;
|
||||
ScoreComponentCollection averageTotalWeightsBeforePrevious;
|
||||
|
||||
// print initial weights
|
||||
cerr << "weights: " << decoder->getWeights() << endl;
|
||||
ScoreComponentCollection mixedAverageWeights;
|
||||
ScoreComponentCollection mixedAverageWeightsPrevious;
|
||||
ScoreComponentCollection mixedAverageWeightsBeforePrevious;
|
||||
|
||||
float averageRatio = 0;
|
||||
float averageBleu = 0;
|
||||
@ -369,11 +365,12 @@ int main(int argc, char** argv) {
|
||||
recvbuf = (float *) malloc(sizeof(float));
|
||||
for (size_t epoch = 0; epoch < epochs && !stop; ++epoch) {
|
||||
cerr << "\nEpoch " << epoch << endl;
|
||||
weightChangesThisEpoch = 0;
|
||||
numberCumulativeWeightsThisEpoch = 1;
|
||||
summedApproxBleu = 0;
|
||||
// Sum up weights over one epoch, final average uses weights from last epoch
|
||||
if (!accumulateWeights) {
|
||||
cumulativeWeights.ZeroAll();
|
||||
cumulativeWeights = decoder->getWeights();
|
||||
}
|
||||
|
||||
// number of weight dumps this epoch
|
||||
@ -408,8 +405,7 @@ int main(int argc, char** argv) {
|
||||
input = inputSentences[*sid];
|
||||
const vector<string>& refs = referenceSentences[*sid];
|
||||
cerr << "Rank " << rank << ", batch position " << batchPosition << endl;
|
||||
cerr << "Rank " << rank << ", input sentence " << *sid << ": \""
|
||||
<< input << "\"" << endl;
|
||||
cerr << "Rank " << rank << ", input sentence " << *sid << ": \"" << input << "\"" << endl;
|
||||
|
||||
vector<ScoreComponentCollection> newFeatureValues;
|
||||
vector<float> newBleuScores;
|
||||
@ -417,46 +413,40 @@ int main(int argc, char** argv) {
|
||||
bleuScores.push_back(newBleuScores);
|
||||
|
||||
// MODEL
|
||||
cerr << "Rank " << rank << ", run decoder to get nbest wrt model score"
|
||||
<< endl;
|
||||
cerr << "Rank " << rank << ", run decoder to get nbest wrt model score" << endl;
|
||||
vector<const Word*> bestModel = decoder->getNBest(input, *sid, n, 0.0,
|
||||
1.0, featureValues[batchPosition], bleuScores[batchPosition], true,
|
||||
distinctNbest, rank);
|
||||
1.0, featureValues[batchPosition], bleuScores[batchPosition], true,
|
||||
distinctNbest, rank);
|
||||
inputLengths.push_back(decoder->getCurrentInputLength());
|
||||
ref_ids.push_back(*sid);
|
||||
all_ref_ids.push_back(*sid);
|
||||
allBestModelScore.push_back(bestModel);
|
||||
decoder->cleanup();
|
||||
cerr << "Rank " << rank << ", model length: " << bestModel.size()
|
||||
<< " Bleu: " << bleuScores[batchPosition][0] << endl;
|
||||
cerr << "Rank " << rank << ", model length: " << bestModel.size() << " Bleu: " << bleuScores[batchPosition][0] << endl;
|
||||
|
||||
// HOPE
|
||||
cerr << "Rank " << rank
|
||||
<< ", run decoder to get nbest hope translations" << endl;
|
||||
cerr << "Rank " << rank << ", run decoder to get nbest hope translations" << endl;
|
||||
size_t oraclePos = featureValues[batchPosition].size();
|
||||
oraclePositions.push_back(oraclePos);
|
||||
vector<const Word*> oracle = decoder->getNBest(input, *sid, n, 1.0,
|
||||
1.0, featureValues[batchPosition], bleuScores[batchPosition], true,
|
||||
distinctNbest, rank);
|
||||
1.0, featureValues[batchPosition], bleuScores[batchPosition], true,
|
||||
distinctNbest, rank);
|
||||
decoder->cleanup();
|
||||
oracles.push_back(oracle);
|
||||
cerr << "Rank " << rank << ", oracle length: " << oracle.size()
|
||||
<< " Bleu: " << bleuScores[batchPosition][oraclePos] << endl;
|
||||
cerr << "Rank " << rank << ", oracle length: " << oracle.size() << " Bleu: " << bleuScores[batchPosition][oraclePos] << endl;
|
||||
|
||||
oracleFeatureValues.push_back(featureValues[batchPosition][oraclePos]);
|
||||
float oracleBleuScore = bleuScores[batchPosition][oraclePos];
|
||||
oracleBleuScores.push_back(oracleBleuScore);
|
||||
|
||||
// FEAR
|
||||
cerr << "Rank " << rank
|
||||
<< ", run decoder to get nbest fear translations" << endl;
|
||||
cerr << "Rank " << rank << ", run decoder to get nbest fear translations" << endl;
|
||||
size_t fearPos = featureValues[batchPosition].size();
|
||||
vector<const Word*> fear = decoder->getNBest(input, *sid, n, -1.0, 1.0,
|
||||
featureValues[batchPosition], bleuScores[batchPosition], true,
|
||||
distinctNbest, rank);
|
||||
featureValues[batchPosition], bleuScores[batchPosition], true,
|
||||
distinctNbest, rank);
|
||||
decoder->cleanup();
|
||||
cerr << "Rank " << rank << ", fear length: " << fear.size()
|
||||
<< " Bleu: " << bleuScores[batchPosition][fearPos] << endl;
|
||||
cerr << "Rank " << rank << ", fear length: " << fear.size() << " Bleu: " << bleuScores[batchPosition][fearPos] << endl;
|
||||
|
||||
// for (size_t i = 0; i < bestModel.size(); ++i) {
|
||||
// delete bestModel[i];
|
||||
@ -509,13 +499,14 @@ int main(int argc, char** argv) {
|
||||
|
||||
int updateStatus = optimiser->updateWeights(mosesWeights, featureValues,
|
||||
losses, bleuScores, oracleFeatureValues, oracleBleuScores, ref_ids,
|
||||
learning_rate, max_sentence_update, rank, update_after_epoch);
|
||||
learning_rate, max_sentence_update, rank, updates_per_epoch);
|
||||
|
||||
// set decoder weights and accumulate weights
|
||||
if (controlUpdates && updateStatus < 0) {
|
||||
// TODO: could try to repeat hildreth with more slack
|
||||
cerr << "update ignored!" << endl;
|
||||
} else if (!update_after_epoch) {
|
||||
} else if (updates_per_epoch == -1) {
|
||||
// if updates are not accumulated, apply new weights now
|
||||
if (normaliseWeights) {
|
||||
mosesWeights.L1Normalise();
|
||||
cerr << "\nRank " << rank << ", new weights (normalised): " << mosesWeights << endl;
|
||||
@ -523,14 +514,25 @@ int main(int argc, char** argv) {
|
||||
cerr << "\nRank " << rank << ", new weights: " << mosesWeights << endl;
|
||||
}
|
||||
|
||||
decoder->setWeights(mosesWeights);
|
||||
cumulativeWeights.PlusEquals(mosesWeights);
|
||||
|
||||
++weightChanges;
|
||||
++weightChangesThisEpoch;
|
||||
++numberCumulativeWeights;
|
||||
++numberCumulativeWeightsThisEpoch;
|
||||
if (useAverageWeightsForPruning) {
|
||||
ScoreComponentCollection averageWeights(cumulativeWeights);
|
||||
if (accumulateWeights) {
|
||||
averageWeights.DivideEquals(numberCumulativeWeights);
|
||||
} else {
|
||||
averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
|
||||
}
|
||||
decoder->setWeights(averageWeights);
|
||||
cerr << "Rank " << rank << ", average weights: " << averageWeights << endl;
|
||||
}
|
||||
else {
|
||||
decoder->setWeights(mosesWeights);
|
||||
}
|
||||
|
||||
// compute difference to old weights
|
||||
ScoreComponentCollection weightDifference(mosesWeights);
|
||||
ScoreComponentCollection weightDifference(decoder->getWeights());
|
||||
weightDifference.MinusEquals(oldWeights);
|
||||
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
|
||||
|
||||
@ -538,15 +540,15 @@ int main(int argc, char** argv) {
|
||||
if (actualBatchSize == 1) {
|
||||
cerr << "\nRank " << rank << ", nbest model score translations with new weights" << endl;
|
||||
vector<const Word*> bestModel = decoder->getNBest(input, *sid, n,
|
||||
0.0, 1.0, featureValues[0], bleuScores[0], true, distinctNbest,
|
||||
rank);
|
||||
0.0, 1.0, featureValues[0], bleuScores[0], true, distinctNbest,
|
||||
rank);
|
||||
decoder->cleanup();
|
||||
cerr << endl;
|
||||
|
||||
cerr << "\nRank " << rank << ", nbest hope translations with new weights" << endl;
|
||||
vector<const Word*> oracle = decoder->getNBest(input, *sid, n,
|
||||
1.0, 1.0, featureValues[0], bleuScores[0], true, distinctNbest,
|
||||
rank);
|
||||
1.0, 1.0, featureValues[0], bleuScores[0], true, distinctNbest,
|
||||
rank);
|
||||
decoder->cleanup();
|
||||
cerr << endl;
|
||||
}
|
||||
@ -576,8 +578,7 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
bool makeUpdate = ((update_after_epoch && sid == shard.end()) || ((update_after > 1) && (shardPosition % shard.size() == update_after)));
|
||||
bool mix = (shardPosition % (shard.size() / mixingFrequency) == 0);
|
||||
bool makeUpdate = updates_per_epoch == -1 ? 0 : (shardPosition % (shard.size() / updates_per_epoch) == 0);
|
||||
|
||||
// apply accumulated updates
|
||||
if (makeUpdate && typeid(*optimiser) == typeid(MiraOptimiser)) {
|
||||
@ -596,64 +597,64 @@ int main(int argc, char** argv) {
|
||||
|
||||
decoder->setWeights(mosesWeights);
|
||||
cumulativeWeights.PlusEquals(mosesWeights);
|
||||
++weightChanges;
|
||||
++weightChangesThisEpoch;
|
||||
++numberCumulativeWeights;
|
||||
++numberCumulativeWeightsThisEpoch;
|
||||
}
|
||||
|
||||
// mix weights
|
||||
if (mix && !update_after_epoch) {
|
||||
ScoreComponentCollection averageWeights;
|
||||
// mix weights?
|
||||
if (shardPosition % (shard.size() / mixingFrequency) == 0) {
|
||||
ScoreComponentCollection mixedWeights;
|
||||
#ifdef MPI_ENABLE
|
||||
cerr << "\nRank " << rank << ", before mixing: " << mosesWeights << endl;
|
||||
|
||||
// collect all weights in averageWeights and divide by number of processes
|
||||
mpi::reduce(world, mosesWeights, averageWeights, SCCPlus(), 0);
|
||||
// collect all weights in mixedWeights and divide by number of processes
|
||||
mpi::reduce(world, mosesWeights, mixedWeights, SCCPlus(), 0);
|
||||
if (rank == 0) {
|
||||
// divide by number of processes
|
||||
averageWeights.DivideEquals(size);
|
||||
mixedWeights.DivideEquals(size);
|
||||
|
||||
// normalise weights after averaging
|
||||
if (normaliseWeights) {
|
||||
averageWeights.L1Normalise();
|
||||
cerr << "Average weights after mixing (normalised): " << averageWeights << endl;
|
||||
mixedWeights.L1Normalise();
|
||||
cerr << "Mixed weights (normalised): " << mixedWeights << endl;
|
||||
}
|
||||
else {
|
||||
cerr << "Average weights after mixing: " << averageWeights << endl;
|
||||
cerr << "Mixed weights: " << mixedWeights << endl;
|
||||
}
|
||||
}
|
||||
|
||||
// broadcast average weights from process 0
|
||||
mpi::broadcast(world, averageWeights, 0);
|
||||
decoder->setWeights(averageWeights);
|
||||
mpi::broadcast(world, mixedWeights, 0);
|
||||
decoder->setWeights(mixedWeights);
|
||||
#endif
|
||||
#ifndef MPI_ENABLE
|
||||
averageWeights = mosesWeights;
|
||||
mixedWeights = mosesWeights;
|
||||
#endif
|
||||
} // end mixing
|
||||
|
||||
// Average and dump weights of all processes over one or more epochs
|
||||
if ((mix && !update_after_epoch) || makeUpdate) {
|
||||
ScoreComponentCollection totalWeights(cumulativeWeights);
|
||||
// Dump weights?
|
||||
if (shardPosition % (shard.size() / weightDumpFrequency) == 0) {
|
||||
ScoreComponentCollection tmpAverageWeights(cumulativeWeights);
|
||||
if (accumulateWeights) {
|
||||
totalWeights.DivideEquals(weightChanges);
|
||||
tmpAverageWeights.DivideEquals(numberCumulativeWeights);
|
||||
} else {
|
||||
totalWeights.DivideEquals(weightChangesThisEpoch);
|
||||
tmpAverageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
|
||||
}
|
||||
|
||||
#ifdef MPI_ENABLE
|
||||
// average across processes
|
||||
mpi::reduce(world, totalWeights, averageTotalWeights, SCCPlus(), 0);
|
||||
mpi::reduce(world, tmpAverageWeights, mixedAverageWeights, SCCPlus(), 0);
|
||||
#endif
|
||||
#ifndef MPI_ENABLE
|
||||
averageTotalWeights = totalWeights;
|
||||
mixedAverageWeights = tmpAverageWeights;
|
||||
#endif
|
||||
if (rank == 0 && !weightDumpStem.empty()) {
|
||||
// divide by number of processes
|
||||
averageTotalWeights.DivideEquals(size);
|
||||
mixedAverageWeights.DivideEquals(size);
|
||||
|
||||
// normalise weights after averaging
|
||||
if (normaliseWeights) {
|
||||
averageTotalWeights.L1Normalise();
|
||||
mixedAverageWeights.L1Normalise();
|
||||
}
|
||||
|
||||
// dump final average weights
|
||||
@ -669,13 +670,13 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
|
||||
if (accumulateWeights) {
|
||||
cerr << "\nAverage total weights (cumulative) during epoch " << epoch << ": " << averageTotalWeights << endl;
|
||||
cerr << "\nMixed average weights (cumulative) during epoch " << epoch << ": " << mixedAverageWeights << endl;
|
||||
} else {
|
||||
cerr << "\nAverage total weights during epoch " << epoch << ": " << averageTotalWeights << endl;
|
||||
cerr << "\nMixed average weights during epoch " << epoch << ": " << mixedAverageWeights << endl;
|
||||
}
|
||||
|
||||
cerr << "Dumping average total weights during epoch " << epoch << " to " << filename.str() << endl;
|
||||
averageTotalWeights.Save(filename.str());
|
||||
cerr << "Dumping mixed average weights during epoch " << epoch << " to " << filename.str() << endl;
|
||||
mixedAverageWeights.Save(filename.str());
|
||||
++weightEpochDump;
|
||||
}
|
||||
}// end averaging and dumping total weights
|
||||
@ -736,7 +737,7 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
|
||||
// average approximate sentence bleu across processes
|
||||
sendbuf[0] = summedApproxBleu/weightChangesThisEpoch;
|
||||
sendbuf[0] = summedApproxBleu/numberCumulativeWeightsThisEpoch;
|
||||
recvbuf[0] = 0;
|
||||
MPI_Reduce(sendbuf, recvbuf, 1, MPI_FLOAT, MPI_SUM, 0, world);
|
||||
if (rank == 0) {
|
||||
@ -750,7 +751,7 @@ int main(int argc, char** argv) {
|
||||
#ifndef MPI_ENABLE
|
||||
averageBleu = bleu;
|
||||
cerr << "Average Bleu (dev) after epoch " << epoch << ": " << averageBleu << endl;
|
||||
averageApproxBleu = summedApproxBleu / weightChangesThisEpoch;
|
||||
averageApproxBleu = summedApproxBleu / numberCumulativeWeightsThisEpoch;
|
||||
cerr << "Average approx. sentence Bleu (dev) after epoch " << epoch << ": " << averageApproxBleu << endl;
|
||||
#endif
|
||||
if (rank == 0) {
|
||||
@ -792,13 +793,13 @@ int main(int argc, char** argv) {
|
||||
// Test if weights have converged
|
||||
if (weightConvergence) {
|
||||
bool reached = true;
|
||||
if (rank == 0) {
|
||||
ScoreComponentCollection firstDiff(averageTotalWeights);
|
||||
firstDiff.MinusEquals(averageTotalWeightsPrevious);
|
||||
if (rank == 0 && (epoch >= 2)) {
|
||||
ScoreComponentCollection firstDiff(mixedAverageWeights);
|
||||
firstDiff.MinusEquals(mixedAverageWeightsPrevious);
|
||||
cerr << "Average weight changes since previous epoch: " << firstDiff
|
||||
<< endl;
|
||||
ScoreComponentCollection secondDiff(averageTotalWeights);
|
||||
secondDiff.MinusEquals(averageTotalWeightsBeforePrevious);
|
||||
ScoreComponentCollection secondDiff(mixedAverageWeights);
|
||||
secondDiff.MinusEquals(mixedAverageWeightsBeforePrevious);
|
||||
cerr << "Average weight changes since before previous epoch: "
|
||||
<< secondDiff << endl << endl;
|
||||
|
||||
@ -830,6 +831,12 @@ int main(int argc, char** argv) {
|
||||
dummy.Save(endfilename.str());
|
||||
}
|
||||
}
|
||||
|
||||
mixedAverageWeightsBeforePrevious = mixedAverageWeightsPrevious;
|
||||
mixedAverageWeightsPrevious = mixedAverageWeights;
|
||||
cerr << "mixed average weights: " << mixedAverageWeights << endl;
|
||||
cerr << "mixed average weights previous: " << mixedAverageWeightsPrevious << endl;
|
||||
cerr << "mixed average weights before previous: " << mixedAverageWeightsBeforePrevious << endl;
|
||||
#ifdef MPI_ENABLE
|
||||
mpi::broadcast(world, stop, 0);
|
||||
#endif
|
||||
|
@ -13,7 +13,7 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
ScoreComponentCollection>& oracleFeatureValues,
|
||||
const vector<float> oracleBleuScores, const vector<size_t> sentenceIds,
|
||||
float learning_rate, float max_sentence_update, size_t rank,
|
||||
bool update_after_epoch) {
|
||||
int updates_per_epoch) {
|
||||
|
||||
// add every oracle in batch to list of oracles (under certain conditions)
|
||||
for (size_t i = 0; i < oracleFeatureValues.size(); ++i) {
|
||||
@ -52,8 +52,9 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
}
|
||||
}
|
||||
|
||||
size_t violatedConstraintsBefore = 0;
|
||||
// vector of feature values differences for all created constraints
|
||||
vector<ScoreComponentCollection> featureValueDiffs;
|
||||
size_t violatedConstraintsBefore = 0;
|
||||
vector<float> lossMarginDistances;
|
||||
|
||||
// find most violated constraint
|
||||
@ -61,11 +62,12 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
ScoreComponentCollection maxViolationfeatureValueDiff;
|
||||
|
||||
float oldDistanceFromOptimum = 0;
|
||||
// iterate over input sentences (1 (online) or more (batch))
|
||||
for (size_t i = 0; i < featureValues.size(); ++i) {
|
||||
size_t sentenceId = sentenceIds[i];
|
||||
if (m_oracles[sentenceId].size() > 1)
|
||||
cerr << "Available oracles for source sentence " << sentenceId << ": "
|
||||
<< m_oracles[sentenceId].size() << endl;
|
||||
cerr << "Available oracles for source sentence " << sentenceId << ": " << m_oracles[sentenceId].size() << endl;
|
||||
// iterate over hypothesis translations for one input sentence
|
||||
for (size_t j = 0; j < featureValues[i].size(); ++j) {
|
||||
// check if optimisation criterion is violated for one hypothesis and the oracle
|
||||
// h(e*) >= h(e_ij) + loss(e_ij)
|
||||
@ -141,7 +143,7 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
|
||||
// run optimisation: compute alphas for all given constraints
|
||||
vector<float> alphas;
|
||||
ScoreComponentCollection totalUpdate;
|
||||
ScoreComponentCollection summedUpdate;
|
||||
if (m_accumulateMostViolatedConstraints && !m_pastAndCurrentConstraints) {
|
||||
m_featureValueDiffs.push_back(maxViolationfeatureValueDiff);
|
||||
m_lossMarginDistances.push_back(maxViolationLossMarginDistance);
|
||||
@ -156,12 +158,11 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
// Update the weight vector according to the alphas and the feature value differences
|
||||
// * w' = w' + delta * Dh_ij ---> w' = w' + delta * (h(e*) - h(e_ij))
|
||||
for (size_t k = 0; k < m_featureValueDiffs.size(); ++k) {
|
||||
// compute update
|
||||
float update = alphas[k];
|
||||
m_featureValueDiffs[k].MultiplyEquals(update);
|
||||
float alpha = alphas[k];
|
||||
m_featureValueDiffs[k].MultiplyEquals(alpha);
|
||||
|
||||
// accumulate update
|
||||
totalUpdate.PlusEquals(m_featureValueDiffs[k]);
|
||||
// sum up update
|
||||
summedUpdate.PlusEquals(m_featureValueDiffs[k]);
|
||||
}
|
||||
} else if (violatedConstraintsBefore > 0) {
|
||||
if (m_pastAndCurrentConstraints) {
|
||||
@ -177,22 +178,20 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
}
|
||||
|
||||
if (m_slack != 0) {
|
||||
alphas = Hildreth::optimise(featureValueDiffs, lossMarginDistances,
|
||||
m_slack);
|
||||
alphas = Hildreth::optimise(featureValueDiffs, lossMarginDistances, m_slack);
|
||||
} else {
|
||||
alphas = Hildreth::optimise(featureValueDiffs, lossMarginDistances);
|
||||
}
|
||||
|
||||
// Update the weight vector according to the alphas and the feature value differences
|
||||
// * w' = w' + delta * Dh_ij ---> w' = w' + delta * (h(e*) - h(e_ij))
|
||||
// * w' = w' + SUM alpha_i * (h_i(oracle) - h_i(hypothesis))
|
||||
for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
|
||||
// compute update
|
||||
float alpha = alphas[k];
|
||||
featureValueDiffs[k].MultiplyEquals(alpha);
|
||||
|
||||
// accumulate update
|
||||
totalUpdate.PlusEquals(featureValueDiffs[k]);
|
||||
}
|
||||
// sum up update
|
||||
summedUpdate.PlusEquals(featureValueDiffs[k]);
|
||||
}
|
||||
} else {
|
||||
cerr << "No constraint violated for this batch" << endl;
|
||||
return 0;
|
||||
@ -200,34 +199,26 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
|
||||
// apply learning rate (fixed or flexible)
|
||||
if (learning_rate != 1) {
|
||||
cerr << "Rank " << rank << ", update before applying learning rate: "
|
||||
<< totalUpdate << endl;
|
||||
totalUpdate.MultiplyEquals(learning_rate);
|
||||
cerr << "Rank " << rank << ", update after applying learning rate: "
|
||||
<< totalUpdate << endl;
|
||||
cerr << "Rank " << rank << ", update before applying learning rate: " << summedUpdate << endl;
|
||||
summedUpdate.MultiplyEquals(learning_rate);
|
||||
cerr << "Rank " << rank << ", update after applying learning rate: " << summedUpdate << endl;
|
||||
}
|
||||
|
||||
// apply threshold scaling
|
||||
if (max_sentence_update != -1) {
|
||||
cerr << "Rank " << rank
|
||||
<< ", update before scaling to max-sentence-update: " << totalUpdate
|
||||
<< endl;
|
||||
totalUpdate.ThresholdScaling(max_sentence_update);
|
||||
cerr << "Rank " << rank
|
||||
<< ", update after scaling to max-sentence-update: " << totalUpdate
|
||||
<< endl;
|
||||
cerr << "Rank " << rank << ", update before scaling to max-sentence-update: " << summedUpdate << endl;
|
||||
summedUpdate.ThresholdScaling(max_sentence_update);
|
||||
cerr << "Rank " << rank << ", update after scaling to max-sentence-update: " << summedUpdate << endl;
|
||||
}
|
||||
|
||||
if (update_after_epoch) {
|
||||
m_accumulatedUpdates.PlusEquals(totalUpdate);
|
||||
if (updates_per_epoch > 0) {
|
||||
m_accumulatedUpdates.PlusEquals(summedUpdate);
|
||||
cerr << "Rank " << rank << ", new accumulated updates:" << m_accumulatedUpdates << endl;
|
||||
} else {
|
||||
// apply update to weight vector
|
||||
cerr << "Rank " << rank << ", weights before update: " << currWeights
|
||||
<< endl;
|
||||
currWeights.PlusEquals(totalUpdate);
|
||||
cerr << "Rank " << rank << ", weights after update: " << currWeights
|
||||
<< endl;
|
||||
cerr << "Rank " << rank << ", weights before update: " << currWeights << endl;
|
||||
currWeights.PlusEquals(summedUpdate);
|
||||
cerr << "Rank " << rank << ", weights after update: " << currWeights << endl;
|
||||
|
||||
// sanity check: how many constraints violated after optimisation?
|
||||
size_t violatedConstraintsAfter = 0;
|
||||
|
@ -39,7 +39,7 @@ namespace Mira {
|
||||
float learning_rate,
|
||||
float max_sentence_update,
|
||||
size_t rank,
|
||||
bool update_after_epoch) = 0;
|
||||
int updates_per_epoch) = 0;
|
||||
};
|
||||
|
||||
class Perceptron : public Optimiser {
|
||||
@ -54,7 +54,7 @@ namespace Mira {
|
||||
float learning_rate,
|
||||
float max_sentence_update,
|
||||
size_t rank,
|
||||
bool update_after_epoch);
|
||||
int updates_per_epoch);
|
||||
};
|
||||
|
||||
class MiraOptimiser : public Optimiser {
|
||||
@ -88,7 +88,7 @@ namespace Mira {
|
||||
float learning_rate,
|
||||
float max_sentence_update,
|
||||
size_t rank,
|
||||
bool update_after_epoch);
|
||||
int updates_per_epoch);
|
||||
|
||||
void setOracleIndices(std::vector<size_t> oracleIndices) {
|
||||
m_oracleIndices= oracleIndices;
|
||||
|
@ -34,7 +34,7 @@ int Perceptron::updateWeights(ScoreComponentCollection& currWeights,
|
||||
float learning_rate,
|
||||
float max_sentence_update,
|
||||
size_t rank,
|
||||
bool update_after_epoch)
|
||||
int updates_per_epoch)
|
||||
{
|
||||
for (size_t i = 0; i < featureValues.size(); ++i) {
|
||||
for (size_t j = 0; j < featureValues[i].size(); ++j) {
|
||||
|
@ -181,6 +181,10 @@ namespace Moses {
|
||||
/*if (i->first != DEFAULT_NAME) {
|
||||
out << value << ", ";
|
||||
}*/
|
||||
if (i->first != DEFAULT_NAME) {
|
||||
out << value << ", ";
|
||||
//out << i->first << "=" << value << ", ";
|
||||
}
|
||||
}
|
||||
out << "}";
|
||||
return out;
|
||||
|
Loading…
Reference in New Issue
Block a user