mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-09-20 07:42:21 +03:00
reintroduce clipping
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/branches/mira-mtm5@3772 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
parent
da95cb2560
commit
fbc8910102
@ -131,7 +131,7 @@ int main(int argc, char** argv) {
|
||||
("past-and-current-constraints", po::value<bool>(&pastAndCurrentConstraints)->default_value(false), "Accumulate most violated constraint per example and use them along all current constraints")
|
||||
("suppress-convergence", po::value<bool>(&suppressConvergence)->default_value(false), "Suppress convergence, fixed number of epochs")
|
||||
("clipping", po::value<float>(&clipping)->default_value(0.01f), "Set a threshold to regularise updates")
|
||||
("fixed-clipping", po::value<bool>(&fixedClipping)->default_value(false), "Use a fixed clipping threshold with SMO (instead of adaptive)");
|
||||
("fixed-clipping", po::value<bool>(&fixedClipping)->default_value(false), "Use a fixed clipping threshold");
|
||||
|
||||
po::options_description cmdline_options;
|
||||
cmdline_options.add(desc);
|
||||
@ -228,6 +228,8 @@ int main(int argc, char** argv) {
|
||||
cerr << "Add only violated constraints? " << onlyViolatedConstraints << endl;
|
||||
cerr << "Using slack? " << slack << endl;
|
||||
cerr << "BP factor: " << BPfactor << endl;
|
||||
cerr << "Fixed clipping? " << fixedClipping << endl;
|
||||
cerr << "clipping: " << clipping << endl;
|
||||
if (learner == "mira") {
|
||||
cerr << "Optimising using Mira" << endl;
|
||||
optimiser = new MiraOptimiser(n, hildreth, marginScaleFactor, onlyViolatedConstraints, clipping, fixedClipping, slack, weightedLossFunction, maxNumberOracles, accumulateMostViolatedConstraints, pastAndCurrentConstraints, order.size());
|
||||
@ -294,8 +296,8 @@ int main(int argc, char** argv) {
|
||||
for (size_t batchPosition = 0; batchPosition < batchSize && sid != shard.end(); ++batchPosition) {
|
||||
const string& input = inputSentences[*sid];
|
||||
const vector<string>& refs = referenceSentences[*sid];
|
||||
cerr << "\nBatch position " << batchPosition << endl;
|
||||
cerr << "Input sentence " << *sid << ": \"" << input << "\"" << endl;
|
||||
cerr << "\nRank " << rank << ", batch position " << batchPosition << endl;
|
||||
cerr << "Rank " << rank << ", input sentence " << *sid << ": \"" << input << "\"" << endl;
|
||||
|
||||
vector<ScoreComponentCollection> newFeatureValues;
|
||||
vector<float> newBleuScores;
|
||||
@ -303,7 +305,7 @@ int main(int argc, char** argv) {
|
||||
bleuScores.push_back(newBleuScores);
|
||||
|
||||
// MODEL
|
||||
cerr << "Run decoder to get nbest wrt model score" << endl;
|
||||
cerr << "Rank " << rank << ", run decoder to get nbest wrt model score" << endl;
|
||||
vector<const Word*> bestModel = decoder->getNBest(input,
|
||||
*sid,
|
||||
n,
|
||||
@ -316,14 +318,15 @@ int main(int argc, char** argv) {
|
||||
inputLengths.push_back(decoder->getCurrentInputLength());
|
||||
ref_ids.push_back(*sid);
|
||||
decoder->cleanup();
|
||||
cerr << "Rank " << rank << ": ";
|
||||
for (size_t i = 0; i < bestModel.size(); ++i) {
|
||||
cerr << *(bestModel[i]) << " ";
|
||||
}
|
||||
cerr << endl;
|
||||
cerr << "model length: " << bestModel.size() << " Bleu: " << bleuScores[batchPosition][0] << endl;
|
||||
cerr << "Rank " << rank << ", model length: " << bestModel.size() << " Bleu: " << bleuScores[batchPosition][0] << endl;
|
||||
|
||||
// HOPE
|
||||
cerr << "Run decoder to get nbest hope translations" << endl;
|
||||
cerr << "Rank " << rank << ", run decoder to get nbest hope translations" << endl;
|
||||
size_t oraclePos = featureValues[batchPosition].size();
|
||||
oraclePositions.push_back(oraclePos);
|
||||
vector<const Word*> oracle = decoder->getNBest(input,
|
||||
@ -337,19 +340,20 @@ int main(int argc, char** argv) {
|
||||
distinctNbest);
|
||||
decoder->cleanup();
|
||||
oracles.push_back(oracle);
|
||||
cerr << "Rank " << rank << ": ";
|
||||
for (size_t i = 0; i < oracle.size(); ++i) {
|
||||
//oracles[batchPosition].push_back(oracle[i]);
|
||||
cerr << *(oracle[i]) << " ";
|
||||
}
|
||||
cerr << endl;
|
||||
cerr << "oracle length: " << oracle.size() << " Bleu: " << bleuScores[batchPosition][oraclePos] << endl;
|
||||
cerr << "Rank " << rank << ", oracle length: " << oracle.size() << " Bleu: " << bleuScores[batchPosition][oraclePos] << endl;
|
||||
|
||||
oracleFeatureValues.push_back(featureValues[batchPosition][oraclePos]);
|
||||
float oracleBleuScore = bleuScores[batchPosition][oraclePos];
|
||||
oracleBleuScores.push_back(oracleBleuScore);
|
||||
|
||||
// FEAR
|
||||
cerr << "Run decoder to get nbest fear translations" << endl;
|
||||
cerr << "Rank " << rank << ", run decoder to get nbest fear translations" << endl;
|
||||
size_t fearPos = featureValues[batchPosition].size();
|
||||
vector<const Word*> fear = decoder->getNBest(input,
|
||||
*sid,
|
||||
@ -361,11 +365,12 @@ int main(int argc, char** argv) {
|
||||
true,
|
||||
distinctNbest);
|
||||
decoder->cleanup();
|
||||
cerr << "Rank " << rank << ": ";
|
||||
for (size_t i = 0; i < fear.size(); ++i) {
|
||||
cerr << *(fear[i]) << " ";
|
||||
}
|
||||
cerr << endl;
|
||||
cerr << "fear length: " << fear.size() << " Bleu: " << bleuScores[batchPosition][fearPos] << endl;
|
||||
cerr << "Rank " << rank << ", fear length: " << fear.size() << " Bleu: " << bleuScores[batchPosition][fearPos] << endl;
|
||||
|
||||
for (size_t i = 0; i < bestModel.size(); ++i) {
|
||||
delete bestModel[i];
|
||||
@ -398,17 +403,21 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
|
||||
// run optimiser on batch
|
||||
cerr << "\nRun optimiser.." << endl;
|
||||
cerr << "\nRank " << rank << ", run optimiser.." << endl;
|
||||
ScoreComponentCollection oldWeights(mosesWeights);
|
||||
int constraintChange = optimiser->updateWeights(mosesWeights, featureValues, losses, bleuScores, oracleFeatureValues, oracleBleuScores, ref_ids);
|
||||
|
||||
// update Moses weights
|
||||
mosesWeights.L1Normalise();
|
||||
decoder->setWeights(mosesWeights);
|
||||
|
||||
ScoreComponentCollection weightDifference(mosesWeights);
|
||||
weightDifference.MinusEquals(oldWeights);
|
||||
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
|
||||
|
||||
// update history (for approximate document Bleu)
|
||||
for (size_t i = 0; i < oracles.size(); ++i) {
|
||||
cerr << "oracle length: " << oracles[i].size() << " ";
|
||||
cerr << "Rank " << rank << ", oracle length: " << oracles[i].size() << " ";
|
||||
}
|
||||
decoder->updateHistory(oracles, inputLengths, ref_ids);
|
||||
|
||||
@ -444,14 +453,14 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
cerr << "\nConstraint change: " << constraintChange << endl;
|
||||
cerr << "Summed (loss - margin) with old weights: " << lossMinusMargin_old << endl;
|
||||
cerr << "Summed (loss - margin) with new weights: " << lossMinusMargin_new << endl;
|
||||
cerr << "\nRank " << rank << ", constraint change: " << constraintChange << endl;
|
||||
cerr << "Rank " << rank << ", summed (loss - margin) with old weights: " << lossMinusMargin_old << endl;
|
||||
cerr << "Rank " << rank << ", summed (loss - margin) with new weights: " << lossMinusMargin_new << endl;
|
||||
if (lossMinusMargin_new > lossMinusMargin_old) {
|
||||
cerr << "Worsening: " << lossMinusMargin_new - lossMinusMargin_old << endl;
|
||||
cerr << "Rank " << rank << ", worsening: " << lossMinusMargin_new - lossMinusMargin_old << endl;
|
||||
|
||||
if (constraintChange < 0) {
|
||||
cerr << "Something is going wrong here.." << endl;
|
||||
cerr << "Rank " << rank << ", something is going wrong here.." << endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -525,7 +534,7 @@ int main(int argc, char** argv) {
|
||||
filename << "_" << weightEpochDump;
|
||||
}
|
||||
|
||||
VERBOSE(1, "Dumping weights for epoch " << epoch << " to " << filename.str() << endl);
|
||||
VERBOSE(1, "Rank 0, dumping weights for epoch " << epoch << " to " << filename.str() << endl);
|
||||
averageTotalWeights.Save(filename.str());
|
||||
|
||||
if (weightEpochDump + 1 == weightDumpFrequency){
|
||||
@ -536,7 +545,7 @@ int main(int argc, char** argv) {
|
||||
for (size_t i = 0; i < list_of_delta_h.size(); ++i) {
|
||||
summedError += (list_of_losses[i] - list_of_delta_h[i].InnerProduct(averageTotalWeights));
|
||||
}
|
||||
cerr << "Rank 0: summed error after dumping weights: " << summedError << " (" << list_of_delta_h.size() << " examples)" << endl;
|
||||
cerr << "Rank 0, summed error after dumping weights: " << summedError << " (" << list_of_delta_h.size() << " examples)" << endl;
|
||||
|
||||
// compare new average weights with previous weights
|
||||
averageTotalWeightsCurrent = averageTotalWeights;
|
||||
@ -565,11 +574,11 @@ int main(int argc, char** argv) {
|
||||
|
||||
if (reached) {
|
||||
// stop MIRA
|
||||
cerr << "\nStopping criterion has been reached after epoch " << epoch << ".. stopping MIRA." << endl << endl;
|
||||
cerr << "Average total weights: " << averageTotalWeights << endl;
|
||||
cerr << "\nRank 0, stopping criterion has been reached after epoch " << epoch << ".. stopping MIRA." << endl << endl;
|
||||
cerr << "Rank 0, average total weights: " << averageTotalWeights << endl;
|
||||
now = time(0); // get current time
|
||||
tm = localtime(&now); // get struct filled out
|
||||
cerr << "End date/time: " << tm->tm_mon+1 << "/" << tm->tm_mday << "/" << tm->tm_year + 1900
|
||||
cerr << "Rank 0, end date/time: " << tm->tm_mon+1 << "/" << tm->tm_mday << "/" << tm->tm_year + 1900
|
||||
<< ", " << tm->tm_hour << ":" << tm->tm_min << ":" << tm->tm_sec << endl;
|
||||
|
||||
delete decoder;
|
||||
@ -596,7 +605,7 @@ int main(int argc, char** argv) {
|
||||
#endif
|
||||
|
||||
if (rank == 0) {
|
||||
cerr << "Average total weights: " << averageTotalWeights << endl;
|
||||
cerr << "Rank 0, average total weights: " << averageTotalWeights << endl;
|
||||
}
|
||||
|
||||
now = time(0); // get current time
|
||||
|
@ -71,7 +71,7 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
|
||||
// iterate over all available oracles (1 if not accumulating, otherwise one per started epoch)
|
||||
for (size_t k = 0; k < m_oracles[sentenceId].size(); ++k) {
|
||||
cerr << "Oracle " << k << ": " << m_oracles[sentenceId][k] << " (BLEU: " << m_bleu_of_oracles[sentenceId][k] << ", model score: " << m_oracles[sentenceId][k].GetWeightedScore() << ")" << endl;
|
||||
//cerr << "Oracle " << k << ": " << m_oracles[sentenceId][k] << " (BLEU: " << m_bleu_of_oracles[sentenceId][k] << ", model score: " << m_oracles[sentenceId][k].GetWeightedScore() << ")" << endl;
|
||||
ScoreComponentCollection featureValueDiff = m_oracles[sentenceId][k];
|
||||
featureValueDiff.MinusEquals(featureValues[i][j]);
|
||||
float modelScoreDiff = featureValueDiff.InnerProduct(currWeights);
|
||||
@ -146,8 +146,18 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
// * w' = w' + delta * Dh_ij ---> w' = w' + delta * (h(e*) - h(e_ij))
|
||||
for (size_t k = 0; k < m_featureValueDiffs.size(); ++k) {
|
||||
// compute update
|
||||
m_featureValueDiffs[k].MultiplyEquals(alphas[k]);
|
||||
cerr << "alpha: " << alphas[k] << endl;
|
||||
float update = alphas[k];
|
||||
if (m_fixedClipping) {
|
||||
if (update > m_c) {
|
||||
update = m_c;
|
||||
}
|
||||
else if (update < -1 * m_c) {
|
||||
update = -1 * m_c;
|
||||
}
|
||||
}
|
||||
|
||||
m_featureValueDiffs[k].MultiplyEquals(update);
|
||||
cerr << "alpha: " << update << endl;
|
||||
|
||||
// apply update to weight vector
|
||||
currWeights.PlusEquals(m_featureValueDiffs[k]);
|
||||
@ -179,7 +189,18 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
// * w' = w' + delta * Dh_ij ---> w' = w' + delta * (h(e*) - h(e_ij))
|
||||
for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
|
||||
// compute update
|
||||
featureValueDiffs[k].MultiplyEquals(alphas[k]);
|
||||
float update = alphas[k];
|
||||
if (m_fixedClipping) {
|
||||
if (update > m_c) {
|
||||
update = m_c;
|
||||
}
|
||||
else if (update < -1 * m_c) {
|
||||
update = -1 * m_c;
|
||||
}
|
||||
}
|
||||
|
||||
featureValueDiffs[k].MultiplyEquals(update);
|
||||
cerr << "alpha: " << update << endl;
|
||||
|
||||
// apply update to weight vector
|
||||
currWeights.PlusEquals(featureValueDiffs[k]);
|
||||
|
Loading…
Reference in New Issue
Block a user