introduce option to use average weights for pruning

git-svn-id: http://svn.statmt.org/repository/mira@3852 cc96ff50-19ce-11e0-b349-13d7f0bd23df
ehasler 2011-03-28 18:11:45 +00:00 committed by Ondrej Bojar
parent f6483df41c
commit 269f1018c3
5 changed files with 131 additions and 129 deletions
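In brief: the new --use-average-weights-for-pruning option (default true) has the decoder prune with the running average of all weights applied so far, i.e. the cumulative weight vector divided by the number of updates (numberCumulativeWeights, or numberCumulativeWeightsThisEpoch when weights are not accumulated across epochs), instead of the most recent weights. A minimal sketch of the scheme, with a plain float vector standing in for Moses' ScoreComponentCollection and hypothetical helper names:

#include <cstddef>
#include <vector>

using Weights = std::vector<float>;

// running elementwise sum of every weight vector applied so far
void plusEquals(Weights& a, const Weights& b) {
  for (std::size_t i = 0; i < a.size(); ++i) a[i] += b[i];
}

// mean of the cumulative vector over n accumulated weight vectors
Weights average(const Weights& cumulative, std::size_t n) {
  Weights avg(cumulative);
  for (float& w : avg) w /= static_cast<float>(n);
  return avg;
}

// usage mirroring the diff: after each accepted update,
//   plusEquals(cumulativeWeights, mosesWeights); ++numberCumulativeWeights;
//   decoder->setWeights(useAverageWeightsForPruning
//       ? average(cumulativeWeights, numberCumulativeWeights)
//       : mosesWeights);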

View File

@@ -124,8 +124,8 @@ int main(int argc, char** argv) {
bool print_feature_values;
bool stop_dev_bleu;
bool stop_approx_dev_bleu;
bool update_after_epoch;
size_t update_after;
int updates_per_epoch;
bool useAverageWeightsForPruning;
po::options_description desc("Allowed options");
desc.add_options()("accumulate-most-violated-constraints", po::value<bool>(&accumulateMostViolatedConstraints)->default_value(false),"Accumulate most violated constraint per example")
("accumulate-weights", po::value<bool>(&accumulateWeights)->default_value(true), "Accumulate and average weights over all epochs")
@@ -140,6 +140,7 @@ int main(int argc, char** argv) {
("decr-sentence-update", po::value<float>(&decrease_sentence_update)->default_value(0), "Decrease maximum weight update by the given value after every epoch")
("dev-bleu", po::value<bool>(&devBleu)->default_value(true), "Compute BLEU score of oracle translations of the whole tuning set")
("distinct-nbest", po::value<bool>(&distinctNbest)->default_value(false), "Use nbest list with distinct translations in inference step")
("weight-dump-frequency", po::value<size_t>(&weightDumpFrequency)->default_value(1), "How often per epoch to dump weights, when using mpi")
("epochs,e", po::value<size_t>(&epochs)->default_value(5), "Number of epochs")
("help", po::value(&help)->zero_tokens()->default_value(false), "Print this help message and exit")
("hildreth", po::value<bool>(&hildreth)->default_value(true), "Use Hildreth's optimisation algorithm")
@@ -171,8 +172,8 @@ int main(int argc, char** argv) {
("stop-dev-bleu", po::value<bool>(&stop_dev_bleu)->default_value(false), "Stop when average Bleu (dev) decreases")
("stop-approx-dev-bleu", po::value<bool>(&stop_approx_dev_bleu)->default_value(false), "Stop when average approx. sentence Bleu (dev) decreases")
("stop-weights", po::value<bool>(&weightConvergence)->default_value(false), "Stop when weights converge")
("update-after-epoch", po::value<bool>(&update_after_epoch)->default_value(false), "Accumulate updates and apply them to the weight vector at the end of an epoch")
("update-after", po::value<size_t>(&update_after)->default_value(1), "Accumulate updates of given number of sentences and apply them to the weight vector afterwards")
("updates-per-epoch", po::value<int>(&updates_per_epoch)->default_value(-1), "Accumulate updates and apply them to the weight vector the specified number of times per epoch")
("use-average-weights-for-pruning", po::value<bool>(&useAverageWeightsForPruning)->default_value(true), "Use total weights (cumulative/weights changes) for pruning instead of current weights")
("use-scaled-reference", po::value<bool>(&useScaledReference)->default_value(true), "Use scaled reference length for comparing target and reference length of phrases")
("verbosity,v", po::value<int>(&verbosity)->default_value(0), "Verbosity level")
("weighted-loss-function", po::value<bool>(&weightedLossFunction)->default_value(false), "Weight the loss of a hypothesis by its Bleu score")
@@ -275,6 +276,7 @@ int main(int argc, char** argv) {
Optimiser* optimiser = NULL;
cerr << "adapt-BP-factor: " << adapt_BPfactor << endl;
cerr << "mix-frequency: " << mixingFrequency << endl;
cerr << "weight-dump-frequency: " << weightDumpFrequency << endl;
cerr << "weight-dump-stem: " << weightDumpStem << endl;
cerr << "shuffle: " << shuffle << endl;
cerr << "hildreth: " << hildreth << endl;
@@ -311,12 +313,8 @@ int main(int argc, char** argv) {
cerr << "stop-dev-bleu: " << stop_dev_bleu << endl;
cerr << "stop-approx-dev-bleu: " << stop_approx_dev_bleu << endl;
cerr << "stop-weights: " << weightConvergence << endl;
cerr << "update-after-epoch: " << update_after_epoch << endl;
if (update_after > 1) {
// make sure the remaining changes of one epoch are applied
update_after_epoch = true;
}
cerr << "updates-per-epoch: " << updates_per_epoch << endl;
cerr << "use-total-weights-for-pruning: " << useAverageWeightsForPruning << endl;
if (learner == "mira") {
cerr << "Optimising using Mira" << endl;
@@ -336,9 +334,12 @@ int main(int argc, char** argv) {
}
//Main loop:
ScoreComponentCollection cumulativeWeights; // collect weights per epoch to produce an average
size_t weightChanges = 0;
size_t weightChangesThisEpoch = 0;
// print initial weights
ScoreComponentCollection initialWeights = decoder->getWeights();
cerr << "weights: " << initialWeights << endl;
ScoreComponentCollection cumulativeWeights(initialWeights); // collect weights per epoch to produce an average
size_t numberCumulativeWeights = 1;
size_t numberCumulativeWeightsThisEpoch = 1;
time_t now = time(0); // get current time
struct tm* tm = localtime(&now); // get struct filled out
@@ -346,14 +347,9 @@ int main(int argc, char** argv) {
<< tm->tm_year + 1900 << ", " << tm->tm_hour << ":" << tm->tm_min << ":"
<< tm->tm_sec << endl;
ScoreComponentCollection averageWeights;
ScoreComponentCollection averageTotalWeights;
ScoreComponentCollection averageTotalWeightsEnd;
ScoreComponentCollection averageTotalWeightsPrevious;
ScoreComponentCollection averageTotalWeightsBeforePrevious;
// print initial weights
cerr << "weights: " << decoder->getWeights() << endl;
ScoreComponentCollection mixedAverageWeights;
ScoreComponentCollection mixedAverageWeightsPrevious;
ScoreComponentCollection mixedAverageWeightsBeforePrevious;
float averageRatio = 0;
float averageBleu = 0;
@@ -369,11 +365,12 @@ int main(int argc, char** argv) {
recvbuf = (float *) malloc(sizeof(float));
for (size_t epoch = 0; epoch < epochs && !stop; ++epoch) {
cerr << "\nEpoch " << epoch << endl;
weightChangesThisEpoch = 0;
numberCumulativeWeightsThisEpoch = 1;
summedApproxBleu = 0;
// Sum up weights over one epoch, final average uses weights from last epoch
if (!accumulateWeights) {
cumulativeWeights.ZeroAll();
cumulativeWeights = decoder->getWeights();
}
// number of weight dumps this epoch
@@ -408,8 +405,7 @@ int main(int argc, char** argv) {
input = inputSentences[*sid];
const vector<string>& refs = referenceSentences[*sid];
cerr << "Rank " << rank << ", batch position " << batchPosition << endl;
cerr << "Rank " << rank << ", input sentence " << *sid << ": \""
<< input << "\"" << endl;
cerr << "Rank " << rank << ", input sentence " << *sid << ": \"" << input << "\"" << endl;
vector<ScoreComponentCollection> newFeatureValues;
vector<float> newBleuScores;
@@ -417,46 +413,40 @@ int main(int argc, char** argv) {
bleuScores.push_back(newBleuScores);
// MODEL
cerr << "Rank " << rank << ", run decoder to get nbest wrt model score"
<< endl;
cerr << "Rank " << rank << ", run decoder to get nbest wrt model score" << endl;
vector<const Word*> bestModel = decoder->getNBest(input, *sid, n, 0.0,
1.0, featureValues[batchPosition], bleuScores[batchPosition], true,
distinctNbest, rank);
1.0, featureValues[batchPosition], bleuScores[batchPosition], true,
distinctNbest, rank);
inputLengths.push_back(decoder->getCurrentInputLength());
ref_ids.push_back(*sid);
all_ref_ids.push_back(*sid);
allBestModelScore.push_back(bestModel);
decoder->cleanup();
cerr << "Rank " << rank << ", model length: " << bestModel.size()
<< " Bleu: " << bleuScores[batchPosition][0] << endl;
cerr << "Rank " << rank << ", model length: " << bestModel.size() << " Bleu: " << bleuScores[batchPosition][0] << endl;
// HOPE
cerr << "Rank " << rank
<< ", run decoder to get nbest hope translations" << endl;
cerr << "Rank " << rank << ", run decoder to get nbest hope translations" << endl;
size_t oraclePos = featureValues[batchPosition].size();
oraclePositions.push_back(oraclePos);
vector<const Word*> oracle = decoder->getNBest(input, *sid, n, 1.0,
1.0, featureValues[batchPosition], bleuScores[batchPosition], true,
distinctNbest, rank);
1.0, featureValues[batchPosition], bleuScores[batchPosition], true,
distinctNbest, rank);
decoder->cleanup();
oracles.push_back(oracle);
cerr << "Rank " << rank << ", oracle length: " << oracle.size()
<< " Bleu: " << bleuScores[batchPosition][oraclePos] << endl;
cerr << "Rank " << rank << ", oracle length: " << oracle.size() << " Bleu: " << bleuScores[batchPosition][oraclePos] << endl;
oracleFeatureValues.push_back(featureValues[batchPosition][oraclePos]);
float oracleBleuScore = bleuScores[batchPosition][oraclePos];
oracleBleuScores.push_back(oracleBleuScore);
// FEAR
cerr << "Rank " << rank
<< ", run decoder to get nbest fear translations" << endl;
cerr << "Rank " << rank << ", run decoder to get nbest fear translations" << endl;
size_t fearPos = featureValues[batchPosition].size();
vector<const Word*> fear = decoder->getNBest(input, *sid, n, -1.0, 1.0,
featureValues[batchPosition], bleuScores[batchPosition], true,
distinctNbest, rank);
featureValues[batchPosition], bleuScores[batchPosition], true,
distinctNbest, rank);
decoder->cleanup();
cerr << "Rank " << rank << ", fear length: " << fear.size()
<< " Bleu: " << bleuScores[batchPosition][fearPos] << endl;
cerr << "Rank " << rank << ", fear length: " << fear.size() << " Bleu: " << bleuScores[batchPosition][fearPos] << endl;
// for (size_t i = 0; i < bestModel.size(); ++i) {
// delete bestModel[i];
@@ -509,13 +499,14 @@ int main(int argc, char** argv) {
int updateStatus = optimiser->updateWeights(mosesWeights, featureValues,
losses, bleuScores, oracleFeatureValues, oracleBleuScores, ref_ids,
learning_rate, max_sentence_update, rank, update_after_epoch);
learning_rate, max_sentence_update, rank, updates_per_epoch);
// set decoder weights and accumulate weights
if (controlUpdates && updateStatus < 0) {
// TODO: could try to repeat hildreth with more slack
cerr << "update ignored!" << endl;
} else if (!update_after_epoch) {
} else if (updates_per_epoch == -1) {
// if updates are not accumulated, apply new weights now
if (normaliseWeights) {
mosesWeights.L1Normalise();
cerr << "\nRank " << rank << ", new weights (normalised): " << mosesWeights << endl;
@@ -523,14 +514,25 @@ int main(int argc, char** argv) {
cerr << "\nRank " << rank << ", new weights: " << mosesWeights << endl;
}
decoder->setWeights(mosesWeights);
cumulativeWeights.PlusEquals(mosesWeights);
++weightChanges;
++weightChangesThisEpoch;
++numberCumulativeWeights;
++numberCumulativeWeightsThisEpoch;
if (useAverageWeightsForPruning) {
ScoreComponentCollection averageWeights(cumulativeWeights);
if (accumulateWeights) {
averageWeights.DivideEquals(numberCumulativeWeights);
} else {
averageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
}
decoder->setWeights(averageWeights);
cerr << "Rank " << rank << ", average weights: " << averageWeights << endl;
}
else {
decoder->setWeights(mosesWeights);
}
// compute difference to old weights
ScoreComponentCollection weightDifference(mosesWeights);
ScoreComponentCollection weightDifference(decoder->getWeights());
weightDifference.MinusEquals(oldWeights);
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
@@ -538,15 +540,15 @@ int main(int argc, char** argv) {
if (actualBatchSize == 1) {
cerr << "\nRank " << rank << ", nbest model score translations with new weights" << endl;
vector<const Word*> bestModel = decoder->getNBest(input, *sid, n,
0.0, 1.0, featureValues[0], bleuScores[0], true, distinctNbest,
rank);
0.0, 1.0, featureValues[0], bleuScores[0], true, distinctNbest,
rank);
decoder->cleanup();
cerr << endl;
cerr << "\nRank " << rank << ", nbest hope translations with new weights" << endl;
vector<const Word*> oracle = decoder->getNBest(input, *sid, n,
1.0, 1.0, featureValues[0], bleuScores[0], true, distinctNbest,
rank);
1.0, 1.0, featureValues[0], bleuScores[0], true, distinctNbest,
rank);
decoder->cleanup();
cerr << endl;
}
@@ -576,8 +578,7 @@ int main(int argc, char** argv) {
}
}
bool makeUpdate = ((update_after_epoch && sid == shard.end()) || ((update_after > 1) && (shardPosition % shard.size() == update_after)));
bool mix = (shardPosition % (shard.size() / mixingFrequency) == 0);
bool makeUpdate = updates_per_epoch == -1 ? 0 : (shardPosition % (shard.size() / updates_per_epoch) == 0);
// apply accumulated updates
if (makeUpdate && typeid(*optimiser) == typeid(MiraOptimiser)) {
@@ -596,64 +597,64 @@ int main(int argc, char** argv) {
decoder->setWeights(mosesWeights);
cumulativeWeights.PlusEquals(mosesWeights);
++weightChanges;
++weightChangesThisEpoch;
++numberCumulativeWeights;
++numberCumulativeWeightsThisEpoch;
}
// mix weights
if (mix && !update_after_epoch) {
ScoreComponentCollection averageWeights;
// mix weights?
if (shardPosition % (shard.size() / mixingFrequency) == 0) {
ScoreComponentCollection mixedWeights;
#ifdef MPI_ENABLE
cerr << "\nRank " << rank << ", before mixing: " << mosesWeights << endl;
// collect all weights in averageWeights and divide by number of processes
mpi::reduce(world, mosesWeights, averageWeights, SCCPlus(), 0);
// collect all weights in mixedWeights and divide by number of processes
mpi::reduce(world, mosesWeights, mixedWeights, SCCPlus(), 0);
if (rank == 0) {
// divide by number of processes
averageWeights.DivideEquals(size);
mixedWeights.DivideEquals(size);
// normalise weights after averaging
if (normaliseWeights) {
averageWeights.L1Normalise();
cerr << "Average weights after mixing (normalised): " << averageWeights << endl;
mixedWeights.L1Normalise();
cerr << "Mixed weights (normalised): " << mixedWeights << endl;
}
else {
cerr << "Average weights after mixing: " << averageWeights << endl;
cerr << "Mixed weights: " << mixedWeights << endl;
}
}
// broadcast average weights from process 0
mpi::broadcast(world, averageWeights, 0);
decoder->setWeights(averageWeights);
mpi::broadcast(world, mixedWeights, 0);
decoder->setWeights(mixedWeights);
#endif
#ifndef MPI_ENABLE
averageWeights = mosesWeights;
mixedWeights = mosesWeights;
#endif
} // end mixing
// Average and dump weights of all processes over one or more epochs
if ((mix && !update_after_epoch) || makeUpdate) {
ScoreComponentCollection totalWeights(cumulativeWeights);
// Dump weights?
if (shardPosition % (shard.size() / weightDumpFrequency) == 0) {
ScoreComponentCollection tmpAverageWeights(cumulativeWeights);
if (accumulateWeights) {
totalWeights.DivideEquals(weightChanges);
tmpAverageWeights.DivideEquals(numberCumulativeWeights);
} else {
totalWeights.DivideEquals(weightChangesThisEpoch);
tmpAverageWeights.DivideEquals(numberCumulativeWeightsThisEpoch);
}
#ifdef MPI_ENABLE
// average across processes
mpi::reduce(world, totalWeights, averageTotalWeights, SCCPlus(), 0);
mpi::reduce(world, tmpAverageWeights, mixedAverageWeights, SCCPlus(), 0);
#endif
#ifndef MPI_ENABLE
averageTotalWeights = totalWeights;
mixedAverageWeights = tmpAverageWeights;
#endif
if (rank == 0 && !weightDumpStem.empty()) {
// divide by number of processes
averageTotalWeights.DivideEquals(size);
mixedAverageWeights.DivideEquals(size);
// normalise weights after averaging
if (normaliseWeights) {
averageTotalWeights.L1Normalise();
mixedAverageWeights.L1Normalise();
}
// dump final average weights
@@ -669,13 +670,13 @@ int main(int argc, char** argv) {
}
if (accumulateWeights) {
cerr << "\nAverage total weights (cumulative) during epoch " << epoch << ": " << averageTotalWeights << endl;
cerr << "\nMixed average weights (cumulative) during epoch " << epoch << ": " << mixedAverageWeights << endl;
} else {
cerr << "\nAverage total weights during epoch " << epoch << ": " << averageTotalWeights << endl;
cerr << "\nMixed average weights during epoch " << epoch << ": " << mixedAverageWeights << endl;
}
cerr << "Dumping average total weights during epoch " << epoch << " to " << filename.str() << endl;
averageTotalWeights.Save(filename.str());
cerr << "Dumping mixed average weights during epoch " << epoch << " to " << filename.str() << endl;
mixedAverageWeights.Save(filename.str());
++weightEpochDump;
}
}// end averaging and dumping total weights
@@ -736,7 +737,7 @@ int main(int argc, char** argv) {
}
// average approximate sentence bleu across processes
sendbuf[0] = summedApproxBleu/weightChangesThisEpoch;
sendbuf[0] = summedApproxBleu/numberCumulativeWeightsThisEpoch;
recvbuf[0] = 0;
MPI_Reduce(sendbuf, recvbuf, 1, MPI_FLOAT, MPI_SUM, 0, world);
if (rank == 0) {
@@ -750,7 +751,7 @@ int main(int argc, char** argv) {
#ifndef MPI_ENABLE
averageBleu = bleu;
cerr << "Average Bleu (dev) after epoch " << epoch << ": " << averageBleu << endl;
averageApproxBleu = summedApproxBleu / weightChangesThisEpoch;
averageApproxBleu = summedApproxBleu / numberCumulativeWeightsThisEpoch;
cerr << "Average approx. sentence Bleu (dev) after epoch " << epoch << ": " << averageApproxBleu << endl;
#endif
if (rank == 0) {
@@ -792,13 +793,13 @@ int main(int argc, char** argv) {
// Test if weights have converged
if (weightConvergence) {
bool reached = true;
if (rank == 0) {
ScoreComponentCollection firstDiff(averageTotalWeights);
firstDiff.MinusEquals(averageTotalWeightsPrevious);
if (rank == 0 && (epoch >= 2)) {
ScoreComponentCollection firstDiff(mixedAverageWeights);
firstDiff.MinusEquals(mixedAverageWeightsPrevious);
cerr << "Average weight changes since previous epoch: " << firstDiff
<< endl;
ScoreComponentCollection secondDiff(averageTotalWeights);
secondDiff.MinusEquals(averageTotalWeightsBeforePrevious);
ScoreComponentCollection secondDiff(mixedAverageWeights);
secondDiff.MinusEquals(mixedAverageWeightsBeforePrevious);
cerr << "Average weight changes since before previous epoch: "
<< secondDiff << endl << endl;
@@ -830,6 +831,12 @@ int main(int argc, char** argv) {
dummy.Save(endfilename.str());
}
}
mixedAverageWeightsBeforePrevious = mixedAverageWeightsPrevious;
mixedAverageWeightsPrevious = mixedAverageWeights;
cerr << "mixed average weights: " << mixedAverageWeights << endl;
cerr << "mixed average weights previous: " << mixedAverageWeightsPrevious << endl;
cerr << "mixed average weights before previous: " << mixedAverageWeightsBeforePrevious << endl;
#ifdef MPI_ENABLE
mpi::broadcast(world, stop, 0);
#endif
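Aside on the mixing step above: every shard.size() / mixingFrequency sentences, each process' weights are reduce-summed onto rank 0, divided by the number of processes (optionally L1-normalised), and broadcast back so all ranks continue from the same mixed vector. The commit does this with Boost.MPI on ScoreComponentCollection; the sketch below shows the same pattern with the plain MPI C API and a float vector (the function name mixWeights is hypothetical):

#include <mpi.h>
#include <vector>

// Each rank contributes its current weights; rank 0 receives the elementwise
// sum, divides by the number of processes, and the result is broadcast so all
// ranks continue from the same mixed vector.
void mixWeights(std::vector<float>& weights, int rank, int size) {
  std::vector<float> mixed(weights.size(), 0.0f);
  MPI_Reduce(weights.data(), mixed.data(), static_cast<int>(weights.size()),
             MPI_FLOAT, MPI_SUM, 0, MPI_COMM_WORLD);
  if (rank == 0)
    for (float& w : mixed) w /= static_cast<float>(size);
  MPI_Bcast(mixed.data(), static_cast<int>(mixed.size()),
            MPI_FLOAT, 0, MPI_COMM_WORLD);
  weights = mixed;
}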

View File

@@ -13,7 +13,7 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
ScoreComponentCollection>& oracleFeatureValues,
const vector<float> oracleBleuScores, const vector<size_t> sentenceIds,
float learning_rate, float max_sentence_update, size_t rank,
bool update_after_epoch) {
int updates_per_epoch) {
// add every oracle in batch to list of oracles (under certain conditions)
for (size_t i = 0; i < oracleFeatureValues.size(); ++i) {
@@ -52,8 +52,9 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
}
}
size_t violatedConstraintsBefore = 0;
// vector of feature values differences for all created constraints
vector<ScoreComponentCollection> featureValueDiffs;
size_t violatedConstraintsBefore = 0;
vector<float> lossMarginDistances;
// find most violated constraint
@@ -61,11 +62,12 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
ScoreComponentCollection maxViolationfeatureValueDiff;
float oldDistanceFromOptimum = 0;
// iterate over input sentences (1 (online) or more (batch))
for (size_t i = 0; i < featureValues.size(); ++i) {
size_t sentenceId = sentenceIds[i];
if (m_oracles[sentenceId].size() > 1)
cerr << "Available oracles for source sentence " << sentenceId << ": "
<< m_oracles[sentenceId].size() << endl;
cerr << "Available oracles for source sentence " << sentenceId << ": " << m_oracles[sentenceId].size() << endl;
// iterate over hypothesis translations for one input sentence
for (size_t j = 0; j < featureValues[i].size(); ++j) {
// check if optimisation criterion is violated for one hypothesis and the oracle
// h(e*) >= h(e_ij) + loss(e_ij)
@@ -141,7 +143,7 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
// run optimisation: compute alphas for all given constraints
vector<float> alphas;
ScoreComponentCollection totalUpdate;
ScoreComponentCollection summedUpdate;
if (m_accumulateMostViolatedConstraints && !m_pastAndCurrentConstraints) {
m_featureValueDiffs.push_back(maxViolationfeatureValueDiff);
m_lossMarginDistances.push_back(maxViolationLossMarginDistance);
@@ -156,12 +158,11 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
// Update the weight vector according to the alphas and the feature value differences
// * w' = w' + delta * Dh_ij ---> w' = w' + delta * (h(e*) - h(e_ij))
for (size_t k = 0; k < m_featureValueDiffs.size(); ++k) {
// compute update
float update = alphas[k];
m_featureValueDiffs[k].MultiplyEquals(update);
float alpha = alphas[k];
m_featureValueDiffs[k].MultiplyEquals(alpha);
// accumulate update
totalUpdate.PlusEquals(m_featureValueDiffs[k]);
// sum up update
summedUpdate.PlusEquals(m_featureValueDiffs[k]);
}
} else if (violatedConstraintsBefore > 0) {
if (m_pastAndCurrentConstraints) {
@@ -177,22 +178,20 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
}
if (m_slack != 0) {
alphas = Hildreth::optimise(featureValueDiffs, lossMarginDistances,
m_slack);
alphas = Hildreth::optimise(featureValueDiffs, lossMarginDistances, m_slack);
} else {
alphas = Hildreth::optimise(featureValueDiffs, lossMarginDistances);
}
// Update the weight vector according to the alphas and the feature value differences
// * w' = w' + delta * Dh_ij ---> w' = w' + delta * (h(e*) - h(e_ij))
// * w' = w' + SUM alpha_i * (h_i(oracle) - h_i(hypothesis))
for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
// compute update
float alpha = alphas[k];
featureValueDiffs[k].MultiplyEquals(alpha);
// accumulate update
totalUpdate.PlusEquals(featureValueDiffs[k]);
}
// sum up update
summedUpdate.PlusEquals(featureValueDiffs[k]);
}
} else {
cerr << "No constraint violated for this batch" << endl;
return 0;
@@ -200,34 +199,26 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
// apply learning rate (fixed or flexible)
if (learning_rate != 1) {
cerr << "Rank " << rank << ", update before applying learning rate: "
<< totalUpdate << endl;
totalUpdate.MultiplyEquals(learning_rate);
cerr << "Rank " << rank << ", update after applying learning rate: "
<< totalUpdate << endl;
cerr << "Rank " << rank << ", update before applying learning rate: " << summedUpdate << endl;
summedUpdate.MultiplyEquals(learning_rate);
cerr << "Rank " << rank << ", update after applying learning rate: " << summedUpdate << endl;
}
// apply threshold scaling
if (max_sentence_update != -1) {
cerr << "Rank " << rank
<< ", update before scaling to max-sentence-update: " << totalUpdate
<< endl;
totalUpdate.ThresholdScaling(max_sentence_update);
cerr << "Rank " << rank
<< ", update after scaling to max-sentence-update: " << totalUpdate
<< endl;
cerr << "Rank " << rank << ", update before scaling to max-sentence-update: " << summedUpdate << endl;
summedUpdate.ThresholdScaling(max_sentence_update);
cerr << "Rank " << rank << ", update after scaling to max-sentence-update: " << summedUpdate << endl;
}
if (update_after_epoch) {
m_accumulatedUpdates.PlusEquals(totalUpdate);
if (updates_per_epoch > 0) {
m_accumulatedUpdates.PlusEquals(summedUpdate);
cerr << "Rank " << rank << ", new accumulated updates:" << m_accumulatedUpdates << endl;
} else {
// apply update to weight vector
cerr << "Rank " << rank << ", weights before update: " << currWeights
<< endl;
currWeights.PlusEquals(totalUpdate);
cerr << "Rank " << rank << ", weights after update: " << currWeights
<< endl;
cerr << "Rank " << rank << ", weights before update: " << currWeights << endl;
currWeights.PlusEquals(summedUpdate);
cerr << "Rank " << rank << ", weights after update: " << currWeights << endl;
// sanity check: how many constraints violated after optimisation?
size_t violatedConstraintsAfter = 0;
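The summedUpdate renaming in this file matches the comment's rule: the update is w' = w + SUM_k alpha_k * (h(oracle) - h(hypothesis_k)), where the alphas come from Hildreth's algorithm and each featureValueDiffs[k] is an oracle-minus-hypothesis feature difference. A sketch of the summation with plain float vectors in place of ScoreComponentCollection (the function name is hypothetical):

#include <cstddef>
#include <vector>

// Sum the alpha-scaled feature differences of all violated constraints into a
// single update vector: update = SUM_k alpha_k * (h(oracle) - h(hypothesis_k)).
std::vector<float> summedUpdate(
    const std::vector<std::vector<float> >& featureValueDiffs,
    const std::vector<float>& alphas) {
  std::vector<float> update(
      featureValueDiffs.empty() ? std::size_t(0) : featureValueDiffs[0].size(), 0.0f);
  for (std::size_t k = 0; k < featureValueDiffs.size(); ++k)
    for (std::size_t d = 0; d < update.size(); ++d)
      update[d] += alphas[k] * featureValueDiffs[k][d];
  return update;
}
// afterwards: scale by learning_rate, clip with ThresholdScaling, then either
// add to the weights now (updates_per_epoch == -1) or accumulate for later.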

View File

@@ -39,7 +39,7 @@ namespace Mira {
float learning_rate,
float max_sentence_update,
size_t rank,
bool update_after_epoch) = 0;
int updates_per_epoch) = 0;
};
class Perceptron : public Optimiser {
@@ -54,7 +54,7 @@ namespace Mira {
float learning_rate,
float max_sentence_update,
size_t rank,
bool update_after_epoch);
int updates_per_epoch);
};
class MiraOptimiser : public Optimiser {
@@ -88,7 +88,7 @@ namespace Mira {
float learning_rate,
float max_sentence_update,
size_t rank,
bool update_after_epoch);
int updates_per_epoch);
void setOracleIndices(std::vector<size_t> oracleIndices) {
m_oracleIndices= oracleIndices;
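The signature change in this header, from bool update_after_epoch to int updates_per_epoch, encodes the convention used throughout the commit; restated as a hypothetical helper (not part of the Optimiser API):

// updates_per_epoch == -1 : apply each update to the weight vector immediately
// updates_per_epoch  >  0 : accumulate updates (m_accumulatedUpdates) and flush
//   them when shardPosition % (shard.size() / updates_per_epoch) == 0
bool applyImmediately(int updates_per_epoch) {
  return updates_per_epoch == -1;
}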

View File

@@ -34,7 +34,7 @@ int Perceptron::updateWeights(ScoreComponentCollection& currWeights,
float learning_rate,
float max_sentence_update,
size_t rank,
bool update_after_epoch)
int updates_per_epoch)
{
for (size_t i = 0; i < featureValues.size(); ++i) {
for (size_t j = 0; j < featureValues[i].size(); ++j) {

View File

@@ -181,6 +181,10 @@ namespace Moses {
/*if (i->first != DEFAULT_NAME) {
out << value << ", ";
}*/
if (i->first != DEFAULT_NAME) {
out << value << ", ";
//out << i->first << "=" << value << ", ";
}
}
out << "}";
return out;