Update weighted loss; adapt --accumulate-most-violated-constraints and --past-and-current-constraints to batching (accumulate 1 constraint per sentence in the batch); add option --one-per-batch for accumulating only 1 constraint per batch.

git-svn-id: http://svn.statmt.org/repository/mira@3878 cc96ff50-19ce-11e0-b349-13d7f0bd23df
ehasler 2011-04-22 19:17:33 +00:00 committed by Ondrej Bojar
parent bc7f8655d6
commit 10abc4fd62
3 changed files with 491 additions and 125 deletions
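In outline, the constraint-accumulation change works as follows. This is a schematic sketch with simplified, invented types (Constraint and selectAccumulated are illustrative, not the Moses classes in the diffs below):

#include <vector>
using namespace std;

// One optimiser constraint: a feature-value difference (a row of A) and its
// loss-minus-model-score margin (the matching entry of b).
struct Constraint {
	vector<float> featureValueDiff;
	float lossMarginDistance;
};

// Decide which constraints from one batch are accumulated across epochs.
// onePerBatch == false: keep the most violated constraint per input sentence.
// onePerBatch == true : keep only the single most violated constraint of the
//                       whole batch (--one-per-batch).
vector<Constraint> selectAccumulated(const vector<vector<Constraint> >& batch, // [sentence][hypothesis]
		bool onePerBatch) {
	vector<Constraint> kept;
	Constraint batchMax;
	batchMax.lossMarginDistance = 0;
	for (size_t i = 0; i < batch.size(); ++i) {
		Constraint sentenceMax;
		sentenceMax.lossMarginDistance = 0;
		for (size_t j = 0; j < batch[i].size(); ++j) {
			if (batch[i][j].lossMarginDistance > sentenceMax.lossMarginDistance)
				sentenceMax = batch[i][j];
			if (batch[i][j].lossMarginDistance > batchMax.lossMarginDistance)
				batchMax = batch[i][j];
		}
		if (!onePerBatch)
			kept.push_back(sentenceMax);
	}
	if (onePerBatch)
		kept.push_back(batchMax);
	return kept;
}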


@@ -38,6 +38,7 @@ namespace mpi = boost::mpi;
#include "ScoreComponentCollection.h"
#include "Decoder.h"
#include "Optimiser.h"
#include "Hildreth.h"
using namespace Mira;
using namespace std;
@@ -80,10 +81,39 @@ bool evaluateModulo(size_t shard_position, size_t mix_or_dump_base, size_t actual
struct RandomIndex {
ptrdiff_t operator()(ptrdiff_t max) {
srand(time(0)); // FIXME: reseeds on every call, which defeats the shuffle; seed once at startup instead.
return static_cast<ptrdiff_t> (rand() % max);
}
};
void shuffleInput(vector<size_t>& order, size_t size, size_t inputSize) {
cerr << "Shuffling input examples.." << endl;
// RandomIndex rindex;
// random_shuffle(order.begin(), order.end(), rindex);
// For now, just rotate: remove the first element and put it at the back
size_t first = order.at(0);
order.erase(order.begin());
order.push_back(first);
}
void createShard(vector<size_t>& order, size_t size, size_t rank, vector<size_t>& shard) {
// Create the shards according to the number of processes used
float shardSize = (float) (order.size()) / size;
size_t shardStart = (size_t) (shardSize * rank);
size_t shardEnd = (size_t) (shardSize * (rank + 1));
if (rank == size - 1)
shardEnd = order.size();
shard.resize(shardEnd - shardStart);
copy(order.begin() + shardStart, order.begin() + shardEnd, shard.begin());
cerr << "shard: ";
for (size_t i = 0; i < shard.size(); ++i) {
cerr << shard[i] << " ";
}
cerr << endl;
}
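// Example: with order.size() == 10 and size == 3 processes, shardSize = 3.33,
// so rank 0 gets order[0..2], rank 1 gets order[3..5], and the last rank picks
// up the remainder, order[6..9].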
int main(int argc, char** argv) {
size_t rank = 0;
size_t size = 1;
@@ -143,11 +173,14 @@ int main(int argc, char** argv) {
bool devBleu;
bool normaliseWeights;
bool one_constraint;
bool one_per_batch;
bool print_feature_values;
bool stop_dev_bleu;
bool stop_approx_dev_bleu;
bool stop_optimal;
bool train_linear_classifier;
int updates_per_epoch;
bool multiplyA;
po::options_description desc("Allowed options");
desc.add_options()("accumulate-most-violated-constraints", po::value<bool>(&accumulateMostViolatedConstraints)->default_value(false),"Accumulate most violated constraint per example")
("accumulate-weights", po::value<bool>(&accumulateWeights)->default_value(false), "Accumulate and average weights over all epochs")
@@ -181,9 +214,11 @@ int main(int argc, char** argv) {
("msf", po::value<float>(&marginScaleFactor)->default_value(1.0), "Margin scale factor, regularises the update by scaling the enforced margin")
("msf-min", po::value<float>(&marginScaleFactorMin)->default_value(1.0), "Minimum value that margin is scaled by")
("msf-step", po::value<float>(&marginScaleFactorStep)->default_value(0), "Decrease margin scale factor iteratively by the value provided")
("nbest,n", po::value<size_t>(&n)->default_value(10), "Number of translations in nbest list")
("multiplyA", po::value<bool>(&multiplyA)->default_value(true), "Multiply A with outcome before passing to Hildreth")
("nbest,n", po::value<size_t>(&n)->default_value(10), "Number of translations in nbest list")
("normalise", po::value<bool>(&normaliseWeights)->default_value(false), "Whether to normalise the updated weights before passing them to the decoder")
("one-constraint", po::value<bool>(&one_constraint)->default_value(false), "Forget about hope and fear and consider only the 1best model translation to formulate a constraint")
("one-per-batch", po::value<bool>(&one_per_batch)->default_value(false), "Only 1 constraint per batch for params --accumulate-most-violated.. and --past-and-current..")
("only-violated-constraints", po::value<bool>(&onlyViolatedConstraints)->default_value(false), "Add only violated constraints to the optimisation problem")
("past-and-current-constraints", po::value<bool>(&pastAndCurrentConstraints)->default_value(false), "Accumulate most violated constraint per example and use them along all current constraints")
("print-feature-values", po::value<bool>(&print_feature_values)->default_value(false), "Print out feature values")
@@ -197,6 +232,7 @@ int main(int argc, char** argv) {
("stop-approx-dev-bleu", po::value<bool>(&stop_approx_dev_bleu)->default_value(false), "Stop when average approx. sentence Bleu (dev) decreases (or no more increases)")
("stop-weights", po::value<bool>(&weightConvergence)->default_value(false), "Stop when weights converge")
("stop-optimal", po::value<bool>(&stop_optimal)->default_value(true), "Stop when the results of optimization do not improve further")
("train-linear-classifier", po::value<bool>(&train_linear_classifier)->default_value(false), "Test algorithm for linear classification")
("updates-per-epoch", po::value<int>(&updates_per_epoch)->default_value(-1), "Accumulate updates and apply them to the weight vector the specified number of times per epoch")
("use-scaled-reference", po::value<bool>(&useScaledReference)->default_value(true), "Use scaled reference length for comparing target and reference length of phrases")
("verbosity,v", po::value<int>(&verbosity)->default_value(0), "Verbosity level")
@@ -210,6 +246,311 @@ int main(int argc, char** argv) {
po::command_line_parser(argc, argv). options(cmdline_options).run(), vm);
po::notify(vm);
if (train_linear_classifier) {
FName name_x("x");
FName name_y("y");
FVector weights;
weights.set(name_x, 0);
weights.set(name_y, 0);
vector<FVector> examples;
vector<int> outcomes;
FVector pos1;
pos1.set(name_x, 1);
pos1.set(name_y, 1);
FVector pos2;
pos2.set(name_x, 2);
pos2.set(name_y, 1);
FVector pos3;
pos3.set(name_x, 3);
pos3.set(name_y, 1);
FVector pos4;
pos4.set(name_x, 8);
pos4.set(name_y, 1);
FVector neg1;
neg1.set(name_x, 1);
neg1.set(name_y, -1);
FVector neg2;
neg2.set(name_x, 2);
neg2.set(name_y, -1);
FVector neg3;
neg3.set(name_x, 3);
neg3.set(name_y, -1);
FVector neg4;
neg4.set(name_x, 8);
neg4.set(name_y, -1);
examples.push_back(pos1);
examples.push_back(neg1);
examples.push_back(pos2);
examples.push_back(neg2);
examples.push_back(pos3);
examples.push_back(neg3);
examples.push_back(pos4);
examples.push_back(neg4);
outcomes.push_back(1);
outcomes.push_back(-1);
outcomes.push_back(1);
outcomes.push_back(-1);
outcomes.push_back(1);
outcomes.push_back(-1);
outcomes.push_back(1);
outcomes.push_back(-1);
// add outlier
FVector pos5;
pos5.set(name_x, 2.5);
pos5.set(name_y, -1.5);
examples.push_back(pos5);
outcomes.push_back(1);
FVector neg5;
neg5.set(name_x, 0);
neg5.set(name_y, -1);
examples.push_back(neg5);
outcomes.push_back(-1);
// create order
vector<size_t> order;
if (rank == 0) {
for (size_t i = 0; i < examples.size(); ++i) {
order.push_back(i);
}
}
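// Note: only rank 0 fills 'order', so with more than one process every other
// rank would end up with an empty shard; this toy test is single-process.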
cerr << "weights: " << weights << endl;
cerr << "slack: " << slack << endl;
bool stop = false;
FVector prevFinalWeights;
FVector prevPrevFinalWeights;
// float epsilon = 0.0001;
float epsilon = 0.001;
for (size_t epoch = 0; !stop && epoch < epochs; ++epoch) {
cerr << "\nEpoch " << epoch << endl;
size_t updatesThisEpoch = 0;
// optionally shuffle the input data --> change order
if (shuffle && rank == 0) {
shuffleInput(order, size, examples.size());
}
// create shard
vector<size_t> shard;
createShard(order, size, rank, shard);
cerr << "\nBefore: y * x_i . w >= 1 ?" << endl;
size_t numberUnsatisfiedConstraintsBefore = 0;
float errorSumBefore = 0;
// outcome_i (x_i . w + b) >= +1 (b = 0 in this example)
for (size_t i = 0; i < examples.size(); ++i) {
float innerProduct = examples[i].inner_product(weights);
float leftHandSide = outcomes[i] * innerProduct;
cerr << outcomes[i] << " * " << innerProduct << " >= " << 1 << "? " << (leftHandSide >= 1) << " (error: " << 1 - leftHandSide << ")" << endl;
float diff = 1 - leftHandSide;
if (diff > epsilon) {
++numberUnsatisfiedConstraintsBefore;
errorSumBefore += 1 - leftHandSide;
}
}
cerr << "unsatisfied constraints before: " << numberUnsatisfiedConstraintsBefore << endl;
cerr << "error sum before: " << errorSumBefore << endl;
// iterate over training data
size_t shardPosition = 0;
vector<size_t>::const_iterator sid = shard.begin();
// Use updatedWeights during one epoch
FVector updatedWeights(weights);
while (sid != shard.end()) {
vector<FVector> A;
vector<FVector> A_mult;
vector<float> b;
vector<float> batch_outcomes;
// collect constraints for batch
size_t actualBatchSize = 0;
for (size_t batchPosition = 0; batchPosition < batchSize && sid != shard.end(); ++batchPosition) {
A.push_back(examples[*sid]);
A_mult.push_back(outcomes[*sid]* examples[*sid]);
b.push_back(1);
batch_outcomes.push_back(outcomes[*sid]);
// next input sentence
++sid;
++actualBatchSize;
++shardPosition;
}
// check if constraints are already satisfied for batch
size_t batch_numberUnsatisfiedConstraintsBefore = 0;
float batch_errorSumBefore = 0;
for (size_t i = 0; i < A.size(); ++i) {
float innerProduct = A[i].inner_product(updatedWeights);
float leftHandSide = batch_outcomes[i] * innerProduct;
cerr << batch_outcomes[i] << " * " << innerProduct << " >= " << 1 << "? " << (leftHandSide >= 1) << " (error: " << 1 - leftHandSide << ")" << endl;
float diff = 1 - leftHandSide;
if (diff > epsilon) {
++batch_numberUnsatisfiedConstraintsBefore;
batch_errorSumBefore += 1 - leftHandSide;
}
}
cerr << "batch: unsatisfied constraints before: " << batch_numberUnsatisfiedConstraintsBefore << endl;
cerr << "batch: error sum before: " << batch_errorSumBefore << endl;
if (batch_numberUnsatisfiedConstraintsBefore != 0) {
// pass constraints to optimizer
for (size_t j = 0; j < A.size(); ++j){
if (multiplyA) {
cerr << "A: " << A_mult[j] << endl;
}
else {
cerr << "A: " << A[j] << endl;
}
}
vector<float> alphas;
if (slack == 0) {
if (multiplyA) {
alphas = Hildreth::optimise(A_mult, b);
}
else {
alphas = Hildreth::optimise(A, b);
}
}
else {
if (multiplyA) {
alphas = Hildreth::optimise(A_mult, b, slack);
}
else{
alphas = Hildreth::optimise(A, b, slack);
}
}
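// Recover the primal update from the dual solution:
// w' = w + SUM_j alpha_j * y_j * x_j, where A[j] holds x_j and
// batch_outcomes[j] holds y_j (so each summand equals alpha_j * A_mult[j]).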
for (size_t j = 0; j < alphas.size(); ++j) {
cerr << "alpha:" << alphas[j] << endl;
updatedWeights += batch_outcomes[j] * A[j] * alphas[j];
}
cerr << "potential new weights: " << updatedWeights << endl;
// check if constraints are satisfied after processing batch
size_t batch_numberUnsatisfiedConstraintsAfter = 0;
float batch_errorSumAfter = 0;
for (size_t i = 0; i < A.size(); ++i) {
float innerProduct = A[i].inner_product(updatedWeights);
float leftHandSide = batch_outcomes[i] * innerProduct;
cerr << batch_outcomes[i] << " * " << innerProduct << " >= " << 1 << "? " << (leftHandSide >= 1) << " (error: " << 1 - leftHandSide << ")" << endl;
float diff = 1 - leftHandSide;
if (diff > epsilon) {
++batch_numberUnsatisfiedConstraintsAfter;
batch_errorSumAfter += 1 - leftHandSide;
}
}
cerr << "batch: unsatisfied constraints after: " << batch_numberUnsatisfiedConstraintsAfter << endl;
cerr << "batch: error sum after: " << batch_errorSumAfter << endl << endl;
}
else {
cerr << "batch: all constraints satisfied." << endl;
}
} // end of batch loop over the shard
cerr << "After: y * x_i . w >= 1 ?" << endl;
size_t numberUnsatisfiedConstraintsAfter = 0;
float errorSumAfter = 0;
for (size_t i = 0; i < examples.size(); ++i) {
float innerProduct = examples[i].inner_product(updatedWeights);
float leftHandSide = outcomes[i] * innerProduct;
cerr << outcomes[i] << " * " << innerProduct << " >= " << 1 << "? " << (leftHandSide >= 1) << " (error: " << 1 - leftHandSide << ")" << endl;
float diff = 1 - leftHandSide;
if (diff > epsilon) {
++numberUnsatisfiedConstraintsAfter;
errorSumAfter += 1 - leftHandSide;
}
}
cerr << "unsatisfied constraints after: " << numberUnsatisfiedConstraintsAfter << endl;
cerr << "error sum after: " << errorSumAfter << endl;
float convergenceEpsilon = 0.0001; // tolerance for comparing error sums (tighter than the constraint epsilon above)
float diff = errorSumAfter - errorSumBefore;
if (numberUnsatisfiedConstraintsAfter == 0) {
weights = updatedWeights;
cerr << "All constraints satisfied during this epoch, stop." << endl;
cerr << "new weights: " << weights << endl;
++updatesThisEpoch;
stop = true;
}
else if (numberUnsatisfiedConstraintsAfter < numberUnsatisfiedConstraintsBefore) {
cerr << "Constraints improved during this epoch." << endl;
weights = updatedWeights;
cerr << "new weights: " << weights << endl;
++updatesThisEpoch;
}
else if (numberUnsatisfiedConstraintsAfter == numberUnsatisfiedConstraintsBefore && errorSumAfter < errorSumBefore) {
cerr << "Error improved during this epoch." << endl;
weights = updatedWeights;
cerr << "new weights: " << weights << endl;
++updatesThisEpoch;
}
else if (numberUnsatisfiedConstraintsAfter == numberUnsatisfiedConstraintsBefore && diff < convergenceEpsilon && diff > -convergenceEpsilon) {
cerr << "No changes to constraints or error during this epoch." << endl;
}
else {
cerr << "Constraints/error got worse during this epoch." << endl;
}
if (updatesThisEpoch == 0) {
stop = true;
cerr << "No more updates, stop." << endl;
}
else if(prevFinalWeights == weights) {
stop = true;
cerr << "Final weights not changing anymore, stop." << endl;
}
else if(prevPrevFinalWeights == weights) {
stop = true;
cerr << "Final weights changing back to previous final weights, take average and stop." << endl;
weights = prevFinalWeights;
weights += prevPrevFinalWeights;
weights /= 2;
}
prevPrevFinalWeights = prevFinalWeights;
prevFinalWeights = weights;
}
cerr << "\nFinal: " << endl;
cerr << weights << endl;
// classify new examples
cerr << "\nTest examples:" << endl;
FVector test_pos1;
test_pos1.set(name_x, 7);
test_pos1.set(name_y, 1);
cerr << "pos1 (7, 1): " << test_pos1.inner_product(weights) << endl;
FVector test_pos2;
test_pos2.set(name_x, 6);
test_pos2.set(name_y, 2);
cerr << "pos2 (2, 2): " << test_pos2.inner_product(weights) << endl;
/* FVector test_pos3;
test_pos3.set(name_x, 1);
test_pos3.set(name_y, 0.5);
cerr << "pos3 (1, 0.5): " << test_pos3.inner_product(weights) << endl;*/
FVector test_neg1;
test_neg1.set(name_x, 7);
test_neg1.set(name_y, -1);
cerr << "neg1 (7, -1): " << test_neg1.inner_product(weights) << endl;
FVector test_neg2;
test_neg2.set(name_x, 6);
test_neg2.set(name_y, -2);
cerr << "neg2 (2, -2): " << test_neg2.inner_product(weights) << endl;
/* FVector test_neg3;
test_neg3.set(name_x, 1);
test_neg3.set(name_y, -0.5);
cerr << "neg3 (1, -0.5): " << test_neg3.inner_product(weights) << endl;*/
exit(0);
}
if (help) {
std::cout << "Usage: " + string(argv[0])
+ " -f mosesini-file -i input-file -r reference-file(s) [options]"
@@ -233,6 +574,11 @@ int main(int argc, char** argv) {
return 1;
}
if (accumulateMostViolatedConstraints && pastAndCurrentConstraints) {
cerr << "Error: the parameters --accumulate-most-violated-constraints and --past-and-current-constraints are mutually exclusive" << endl;
return 1;
}
// load input and references
vector<string> inputSentences;
if (!loadSentences(inputFile, inputSentences)) {
@@ -345,7 +691,7 @@ int main(int argc, char** argv) {
cerr << "Optimising using Mira" << endl;
optimiser = new MiraOptimiser(n, hildreth, marginScaleFactor,
onlyViolatedConstraints, slack, weightedLossFunction, maxNumberOracles,
accumulateMostViolatedConstraints, pastAndCurrentConstraints, one_per_batch,
order.size());
if (hildreth) {
cerr << "Using Hildreth's optimisation algorithm.." << endl;
@@ -521,7 +867,7 @@ int main(int argc, char** argv) {
}
}
cerr << "Rank " << rank << ", " << *sid << ", best model Bleu (approximate sentence bleu): " << bleuScores[batchPosition][0] << endl;
cerr << "Rank " << rank << ", sentence " << *sid << ", best model Bleu (approximate sentence bleu): " << bleuScores[batchPosition][0] << endl;
summedApproxBleu += bleuScores[batchPosition][0];
// next input sentence

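The toy classifier above repeatedly solves a small quadratic program. In the notation of the code (bias b = 0, example vector x_i, outcome y_i, C the --slack value), and assuming Hildreth::optimise solves the usual MIRA dual, the step being solved can be written as

\min_{\Delta w} \; \tfrac{1}{2}\,\lVert \Delta w \rVert^2
\quad \text{s.t.} \quad y_i \, x_i \cdot \Delta w \;\ge\; b_i \quad \forall i

with dual

\max_{\alpha} \; \sum_i \alpha_i b_i \;-\; \tfrac{1}{2}\,\Bigl\lVert \sum_i \alpha_i \, y_i x_i \Bigr\rVert^2
\quad \text{s.t.} \quad 0 \le \alpha_i \le C

and primal update \Delta w = \sum_i \alpha_i \, y_i x_i, which is exactly the loop updatedWeights += batch_outcomes[j] * A[j] * alphas[j] above. Since the code always passes b_i = 1, each batch constraint asks for a margin improvement of 1 relative to the current weights.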

@@ -36,10 +36,16 @@ vector<int> MiraOptimiser::updateWeightsAnalytically(ScoreComponentCollection& c
// compute alpha for given constraint: (loss - model score diff) / || feature value diff ||^2
// featureValueDiff.GetL2Norm() * featureValueDiff.GetL2Norm() == featureValueDiff.InnerProduct(featureValueDiff)
// from Crammer&Singer 2006: alpha = min {C , l_t/ ||x||^2}
cerr << "Rank " << rank << ", epoch " << epoch << ", feature value diff: " << featureValueDiff << endl;
float squaredNorm = featureValueDiff.GetL2Norm() * featureValueDiff.GetL2Norm();
if (squaredNorm > 0) {
float alpha = (lossMinusModelScoreDiff) / squaredNorm;
if (m_slack > 0 && alpha > m_slack) {
alpha = m_slack;
}
cerr << "Rank " << rank << ", epoch " << epoch << ", alpha: " << alpha << endl;
featureValueDiff.MultiplyEquals(alpha);
weightUpdate.PlusEquals(featureValueDiff);
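The closed-form step in this hunk is the clipped update from the Crammer & Singer reference in the comment: for loss l_t, feature-value difference \Delta h_t and current weights w,

\alpha_t \;=\; \min\Bigl\{\, C,\; \frac{l_t - \Delta h_t \cdot w}{\lVert \Delta h_t \rVert^2} \,\Bigr\},
\qquad
w \;\leftarrow\; w + \alpha_t \, \Delta h_t,

where C is m_slack; when m_slack == 0 the alpha is left unclipped, matching the m_slack > 0 guard above.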
@@ -142,7 +148,7 @@ vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
for (size_t j = 0; j < m_oracles[sentenceId].size(); ++j) {
float currentWeightedScore = m_oracles[sentenceId][j].GetWeightedScore();
if (currentWeightedScore == newWeightedScore) {
cerr << "updated.." << endl;
cerr << "Rank " << rank << ", Bleu score of oracle updated at batch position " << i << ", " << m_bleu_of_oracles[sentenceId][j] << " --> " << oracleBleuScores[j] << endl;
m_bleu_of_oracles[sentenceId][j] = oracleBleuScores[j];
updated = true;
break;
@@ -170,9 +176,13 @@ vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
int violatedConstraintsBefore = 0;
vector<float> lossMinusModelScoreDiffs;
// find most violated constraint for each input sentence in batch
vector<float> max_lossMarginDistance(featureValues.size());
vector<ScoreComponentCollection> max_featureValueDiff(featureValues.size());
// find most violated constraint in batch
float max_batch_lossMarginDistance = 0;
ScoreComponentCollection max_batch_featureValueDiff;
float epsilon = 0.0001;
float oldDistanceFromOptimum = 0;
@@ -196,7 +206,7 @@ vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
float modelScoreDiff = featureValueDiff.InnerProduct(currWeights);
float loss = losses[i][j] * m_marginScaleFactor;
if (m_weightedLossFunction) {
loss *= bleuScores[i][j];
}
// cerr << "Rank " << rank << ", model score diff: " << modelScoreDiff << ", loss: " << loss << endl;
@@ -216,30 +226,33 @@ vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
if (addConstraint) {
float lossMinusModelScoreDiff = loss - modelScoreDiff;
if (m_accumulateMostViolatedConstraints || m_pastAndCurrentConstraints) {
if (m_one_per_batch) {
// accumulate the most violated constraint per batch
if (lossMinusModelScoreDiff > max_batch_lossMarginDistance) {
max_batch_lossMarginDistance = lossMinusModelScoreDiff;
max_batch_featureValueDiff = featureValueDiff;
}
}
else {
// accumulate the most violated constraint for each input sentence
if (lossMinusModelScoreDiff > max_lossMarginDistance[i]) {
max_lossMarginDistance[i] = lossMinusModelScoreDiff;
max_featureValueDiff[i] = featureValueDiff;
}
}
}
if (!m_accumulateMostViolatedConstraints) {
// Objective: 1/2 * ||w' - w||^2 + C * SUM_1_m[ max_1_n (l_ij - Delta_h_ij.w')]
// To add a constraint for the optimiser for each sentence i and hypothesis j, we need:
// 1. vector Delta_h_ij of the feature value differences (oracle - hypothesis)
// 2. loss_ij - difference in model scores (Delta_h_ij.w') (oracle - hypothesis)
featureValueDiffs.push_back(featureValueDiff);
// cerr << "feature value diff (A): " << featureValueDiff << endl;
lossMinusModelScoreDiffs.push_back(lossMinusModelScoreDiff);
// cerr << "loss - model score diff (b): " << lossMarginDistance << endl << endl;
}
}
}
}
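Written out, the objective from the comment above (with m sentences, n-best hypotheses, \Delta h_{ij} = h(\text{oracle}_i) - h(e_{ij}) and hinge (\cdot)_+):

\min_{w'} \;\; \tfrac{1}{2}\,\lVert w' - w \rVert^2 \;+\; C \sum_{i=1}^{m} \max_{1 \le j \le n} \bigl( \ell_{ij} - \Delta h_{ij} \cdot w' \bigr)_{+}

Each featureValueDiff / lossMinusModelScoreDiff pair pushed here is one (\Delta h_{ij}, \ell_{ij} - \Delta h_{ij} \cdot w) constraint later handed to Hildreth::optimise.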
@@ -252,108 +265,103 @@ vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
}
}
// run optimisation: compute alphas for all given constraints
vector<float> alphas;
ScoreComponentCollection summedUpdate;
if (m_pastAndCurrentConstraints) {
// add all past (most violated) constraints to the list of current constraints
for (size_t i = 0; i < m_featureValueDiffs.size(); ++i) {
featureValueDiffs.push_back(m_featureValueDiffs[i]);
lossMinusModelScoreDiffs.push_back(m_lossMarginDistances[i]);
}
}
if (m_accumulateMostViolatedConstraints || m_pastAndCurrentConstraints) {
// add all new most violated constraints (per input sentence in batch) to the list
if (m_one_per_batch) {
m_featureValueDiffs.push_back(max_batch_featureValueDiff);
m_lossMarginDistances.push_back(max_batch_lossMarginDistance);
}
else{
for (size_t i = 0; i < max_featureValueDiff.size(); ++i) {
m_featureValueDiffs.push_back(max_featureValueDiff[i]);
m_lossMarginDistances.push_back(max_lossMarginDistance[i]);
}
}
}
if (m_accumulateMostViolatedConstraints) {
featureValueDiffs = m_featureValueDiffs;
lossMinusModelScoreDiffs = m_lossMarginDistances;
}
if (violatedConstraintsBefore > 0) {
cerr << "Rank " << rank << ", number of constraints passed to optimizer: " << featureValueDiffs.size() << endl;
if (m_slack != 0) {
alphas = Hildreth::optimise(featureValueDiffs, lossMinusModelScoreDiffs, m_slack);
} else {
alphas = Hildreth::optimise(featureValueDiffs, lossMinusModelScoreDiffs);
}
// Update the weight vector according to the alphas and the feature value differences
// * w' = w' + SUM alpha_i * (h_i(oracle) - h_i(hypothesis))
for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
float alpha = alphas[k];
featureValueDiffs[k].MultiplyEquals(alpha);
// sum up update
summedUpdate.PlusEquals(featureValueDiffs[k]);
}
}
else {
cerr << "Rank " << rank << ", epoch " << epoch << ", check, no constraint violated for this batch" << endl;
vector<int> status(3);
status[0] = 1;
status[1] = 0;
status[2] = 0;
return status;
}
ScoreComponentCollection newWeights(currWeights);
newWeights.PlusEquals(summedUpdate);
if (updates_per_epoch > 0) {
newWeights.PlusEquals(m_accumulatedUpdates);
}
int violatedConstraintsAfter = 0;
float newDistanceFromOptimum = 0;
// sanity check: still violated constraints after optimisation?
if (!m_accumulateMostViolatedConstraints && !m_pastAndCurrentConstraints) { // constraint checking not implemented for these cases
for (size_t i = 0; i < featureValues.size(); ++i) {
for (size_t j = 0; j < featureValues[i].size(); ++j) {
ScoreComponentCollection featureValueDiff = oracleFeatureValues[i];
featureValueDiff.MinusEquals(featureValues[i][j]);
float modelScoreDiff = featureValueDiff.InnerProduct(newWeights);
float loss = losses[i][j] * m_marginScaleFactor;
// cerr << "Rank " << rank << ", new model score diff: " << modelScoreDiff << ", loss: " << loss << endl;
float diff = loss - modelScoreDiff;
// approximate comparison between floats!
if (diff > epsilon) {
++violatedConstraintsAfter;
newDistanceFromOptimum += (loss - modelScoreDiff);
}
}
}
cerr << "Rank " << rank << ", epoch " << epoch << ", check, violated constraint before: " << violatedConstraintsBefore << ", after: " << violatedConstraintsAfter << ", change: " << violatedConstraintsBefore - violatedConstraintsAfter << endl;
cerr << "Rank " << rank << ", epoch " << epoch << ", check, error before: " << oldDistanceFromOptimum << ", after: " << newDistanceFromOptimum << ", change: " << oldDistanceFromOptimum - newDistanceFromOptimum << endl;
if (violatedConstraintsAfter > 0) {
float distanceChange = oldDistanceFromOptimum - newDistanceFromOptimum;
if (controlUpdates && (violatedConstraintsBefore - violatedConstraintsAfter) < 0 && distanceChange < 0) {
vector<int> statusPlus(3);
statusPlus[0] = -1;
statusPlus[1] = violatedConstraintsBefore;
statusPlus[2] = violatedConstraintsAfter;
return statusPlus;
}
}
}
// apply learning rate (fixed or flexible)
@@ -382,11 +390,20 @@ vector<int> MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
}
vector<int> statusPlus(3);
if (!m_accumulateMostViolatedConstraints && !m_pastAndCurrentConstraints) {
statusPlus[0] = 0;
statusPlus[1] = violatedConstraintsBefore;
statusPlus[2] = violatedConstraintsAfter;
return statusPlus;
}
else {
statusPlus[0] = 0;
statusPlus[1] = -1;
statusPlus[2] = -1;
return statusPlus;
}
}
}
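For reference, Hildreth's algorithm behind the Hildreth::optimise calls is cyclic projected coordinate ascent on the dual above. A minimal self-contained sketch using plain std::vector features instead of ScoreComponentCollection (the function name and signature are illustrative, not the Moses API):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Maximise  sum_i alpha_i * b_i - 1/2 * || sum_i alpha_i * A_i ||^2
// subject to 0 <= alpha_i <= C (pass C <= 0 for the unbounded case),
// where A_i is a feature-value difference and b_i its loss margin.
std::vector<float> hildrethSketch(const std::vector<std::vector<float> >& A,
		const std::vector<float>& b, float C, std::size_t maxIter = 100) {
	std::size_t m = A.size();
	std::size_t dim = A.empty() ? 0 : A[0].size();
	std::vector<float> alpha(m, 0.0f);
	std::vector<float> w(dim, 0.0f); // w = sum_i alpha_i * A_i
	for (std::size_t iter = 0; iter < maxIter; ++iter) {
		float maxChange = 0.0f;
		for (std::size_t i = 0; i < m; ++i) {
			float dot = 0.0f, sqNorm = 0.0f;
			for (std::size_t d = 0; d < dim; ++d) {
				dot += A[i][d] * w[d];
				sqNorm += A[i][d] * A[i][d];
			}
			if (sqNorm == 0.0f) continue;
			// Exact coordinate-wise optimum, then project onto [0, C].
			float newAlpha = alpha[i] + (b[i] - dot) / sqNorm;
			if (newAlpha < 0.0f) newAlpha = 0.0f;
			if (C > 0.0f && newAlpha > C) newAlpha = C;
			float delta = newAlpha - alpha[i];
			for (std::size_t d = 0; d < dim; ++d)
				w[d] += delta * A[i][d];
			alpha[i] = newAlpha;
			maxChange = std::max(maxChange, std::fabs(delta));
		}
		if (maxChange < 1e-6f) break; // converged
	}
	return alpha;
}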


@@ -88,7 +88,7 @@ namespace Mira {
MiraOptimiser() :
Optimiser() { }
MiraOptimiser(size_t n, bool hildreth, float marginScaleFactor, bool onlyViolatedConstraints, float slack, bool weightedLossFunction, size_t maxNumberOracles, bool accumulateMostViolatedConstraints, bool pastAndCurrentConstraints, bool one_per_batch, size_t exampleSize) :
Optimiser(),
m_n(n),
m_hildreth(hildreth),
@@ -99,6 +99,7 @@ namespace Mira {
m_max_number_oracles(maxNumberOracles),
m_accumulateMostViolatedConstraints(accumulateMostViolatedConstraints),
m_pastAndCurrentConstraints(pastAndCurrentConstraints),
m_one_per_batch(one_per_batch),
m_oracles(exampleSize),
m_bleu_of_oracles(exampleSize) { }
@@ -185,6 +186,8 @@ namespace Mira {
bool m_pastAndCurrentConstraints;
bool m_one_per_batch;
Moses::ScoreComponentCollection m_accumulatedUpdates;
};
}
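The updated construction site in the trainer (first file, next to the "Optimising using Mira" message) passes the new flag straight through; for reference:

// As in the trainer's main(): the new one_per_batch argument sits between
// pastAndCurrentConstraints and the example-set size.
optimiser = new MiraOptimiser(n, hildreth, marginScaleFactor,
		onlyViolatedConstraints, slack, weightedLossFunction, maxNumberOracles,
		accumulateMostViolatedConstraints, pastAndCurrentConstraints, one_per_batch,
		order.size());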