compute BLEU on oracle translations of dev set; introduce params --min-weight-change, --msf-step, --msf-min, --slack-step, --slack-max, --decoder-settings

git-svn-id: http://svn.statmt.org/repository/mira@3804 cc96ff50-19ce-11e0-b349-13d7f0bd23df
ehasler 2011-02-24 10:54:16 +00:00 committed by Ondrej Bojar
parent c4acd339fd
commit 5312e8fc22
7 changed files with 400 additions and 174 deletions

View File

@@ -43,7 +43,7 @@ namespace Mira {
return c;
}
void initMoses(const string& inifile, int debuglevel, int argc, char** argv) {
void initMoses(const string& inifile, int debuglevel, int argc, vector<string> decoder_params) {
static int BASE_ARGC = 5;
Parameter* params = new Parameter();
char ** mosesargv = new char*[BASE_ARGC + argc];
@@ -56,8 +56,10 @@ namespace Mira {
mosesargv[4] = strToChar("-mbr"); //so we can do nbest
for (int i = 0; i < argc; ++i) {
mosesargv[BASE_ARGC + i] = argv[i];
char *cstr = &(decoder_params[i])[0];
mosesargv[BASE_ARGC + i] = cstr;
}
params->LoadParam(BASE_ARGC + argc,mosesargv);
StaticData::LoadDataStatic(params);
for (int i = 0; i < BASE_ARGC; ++i) {
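The replacement above hands each --decoder-settings token to Moses by taking a writable pointer into the std::string's buffer. A minimal sketch of that idiom, outside the repository (loadParams is a stand-in for Parameter::LoadParam, and the flag values are illustrative); the pointers stay valid only while decoder_params is alive and unmodified:

#include <iostream>
#include <string>
#include <vector>

// Stand-in for Moses' Parameter::LoadParam: just echo the argv it receives.
void loadParams(int argc, char** argv) {
    for (int i = 0; i < argc; ++i)
        std::cout << argv[i] << std::endl;
}

int main() {
    std::vector<std::string> decoder_params;
    decoder_params.push_back("-distortion-limit");
    decoder_params.push_back("6");

    std::vector<char*> mosesargv(decoder_params.size());
    for (size_t i = 0; i < decoder_params.size(); ++i)
        mosesargv[i] = &(decoder_params[i])[0];  // same idiom as in the hunk above
    loadParams((int)mosesargv.size(), mosesargv.empty() ? NULL : &mosesargv[0]);
    return 0;
}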
@@ -68,21 +70,16 @@ namespace Mira {
MosesDecoder::MosesDecoder(const vector<vector<string> >& refs, bool useScaledReference, bool scaleByInputLength, float BPfactor, float historySmoothing)
: m_manager(NULL) {
// force initialisation of the phrase dictionary
// force initialisation of the phrase dictionary (TODO: what for?)
const StaticData &staticData = StaticData::Instance();
// is this needed?
//m_sentence = new Sentence(Input);
//stringstream in("Initialising decoder..\n");
//const std::vector<FactorType> &inputFactorOrder = staticData.GetInputFactorOrder();
//m_sentence->Read(in,inputFactorOrder);
m_sentence = new Sentence(Input);
stringstream in("Initialising decoder..\n");
const std::vector<FactorType> &inputFactorOrder = staticData.GetInputFactorOrder();
m_sentence->Read(in,inputFactorOrder);
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
// is this needed?
//(TranslationSystem::DEFAULT);
//m_manager = new Manager(*m_sentence, staticData.GetSearchAlgorithm(), &system);
//m_manager->ProcessSentence();
m_manager = new Manager(*m_sentence, staticData.GetSearchAlgorithm(), &system);
m_manager->ProcessSentence();
// Add the bleu feature
m_bleuScoreFeature = new BleuScoreFeature(useScaledReference, scaleByInputLength, BPfactor, historySmoothing);
@@ -113,8 +110,7 @@ namespace Mira {
stringstream in(source + "\n");
const std::vector<FactorType> &inputFactorOrder = staticData.GetInputFactorOrder();
m_sentence->Read(in,inputFactorOrder);
const TranslationSystem& system = staticData.GetTranslationSystem
(TranslationSystem::DEFAULT);
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
// set the weight for the bleu feature
ostringstream bleuWeightStr;
@@ -212,5 +208,19 @@ namespace Mira {
void MosesDecoder::updateHistory(const vector< vector< const Word*> >& words, vector<size_t>& sourceLengths, vector<size_t>& ref_ids) {
m_bleuScoreFeature->UpdateHistory(words, sourceLengths, ref_ids);
}
void MosesDecoder::calculateBleuOfCorpus(const vector< vector< const Word*> >& words, vector<size_t>& ref_ids, size_t epoch) {
vector<float> bleu = m_bleuScoreFeature->CalculateBleuOfCorpus(words, ref_ids);
cerr << "\nBleu after epoch " << epoch << ": ";
if (bleu.size() > 0) {
cerr << "\nBLEU: " << bleu[4]*100 << ", "
<< bleu[3]*100 << "/" << bleu[2]*100 << "/" << bleu[1]*100 << "/" << bleu[0]*100 << " "
<< "(BP=" << bleu[5] << ", " << "ratio=" << bleu[6] << ", "
<< "hyp_len=" << bleu[7] << ", ref_len=" << bleu[8] << ")" << endl;
}
else {
cerr << "BLEU: 0" << endl;
}
}
}
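For context on the print-out in calculateBleuOfCorpus: CalculateBleuOfCorpus (added later in this commit) packs its return vector as bleu[0..3] = modified 1..4-gram precisions, bleu[4] = overall BLEU, bleu[5] = brevity penalty, bleu[6] = hypothesis/reference length ratio, bleu[7] = hypothesis length, bleu[8] = reference length; note that the statement above therefore prints the precisions in 4/3/2/1-gram order. An epoch line would look roughly like this (values purely illustrative):

Bleu after epoch 0:
BLEU: 27.31, 10.02/15.48/21.96/58.73 (BP=0.98, ratio=0.97, hyp_len=21345, ref_len=22010)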

View File

@@ -42,7 +42,7 @@ namespace Mira {
* Initialise moses (including StaticData) using the given ini file and debuglevel, passing through any
* other command line arguments.
**/
void initMoses(const std::string& inifile, int debuglevel, int argc=0, char** argv=NULL );
void initMoses(const std::string& inifile, int debuglevel, int argc, std::vector<std::string> decoder_params);
/**
@@ -67,6 +67,7 @@ class MosesDecoder {
size_t getCurrentInputLength();
void updateHistory(const std::vector<const Moses::Word*>& words);
void updateHistory(const std::vector< std::vector< const Moses::Word*> >& words, std::vector<size_t>& sourceLengths, std::vector<size_t>& ref_ids);
void calculateBleuOfCorpus(const std::vector< std::vector< const Moses::Word*> >& words, std::vector<size_t>& ref_ids, size_t epoch);
Moses::ScoreComponentCollection getWeights();
void setWeights(const Moses::ScoreComponentCollection& weights);
void cleanup();

View File

@@ -24,6 +24,8 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include <vector>
#include <boost/program_options.hpp>
#include <boost/algorithm/string.hpp>
#ifdef MPI_ENABLE
#include <boost/mpi.hpp>
namespace mpi = boost::mpi;
@@ -83,6 +85,8 @@ int main(int argc, char** argv) {
size_t weightDumpFrequency;
string weightDumpStem;
float marginScaleFactor;
float marginScaleFactorStep;
float marginScaleFactorMin;
bool weightedLossFunction;
size_t n;
size_t batchSize;
@@ -94,6 +98,8 @@ int main(int argc, char** argv) {
bool scaleByInputLength;
float BPfactor;
float slack;
float slack_step;
float slack_max;
size_t maxNumberOracles;
bool accumulateMostViolatedConstraints;
bool pastAndCurrentConstraints;
@@ -104,6 +110,9 @@ int main(int argc, char** argv) {
size_t baseOfLog;
float clipping;
bool fixedClipping;
string decoder_settings;
float min_weight_change;
bool devBleu;
po::options_description desc("Allowed options");
desc.add_options()
("help",po::value( &help )->zero_tokens()->default_value(false), "Print this help message and exit")
@@ -118,7 +127,9 @@ int main(int argc, char** argv) {
("weight-dump-frequency", po::value<size_t>(&weightDumpFrequency)->default_value(1), "How often per epoch to dump weights")
("shuffle", po::value<bool>(&shuffle)->default_value(false), "Shuffle input sentences before processing")
("hildreth", po::value<bool>(&hildreth)->default_value(true), "Use Hildreth's optimisation algorithm")
("margin-scale-factor,m", po::value<float>(&marginScaleFactor)->default_value(1.0), "Margin scale factor, regularises the update by scaling the enforced margin")
("msf", po::value<float>(&marginScaleFactor)->default_value(1.0), "Margin scale factor, regularises the update by scaling the enforced margin")
("msf-step", po::value<float>(&marginScaleFactorStep)->default_value(0), "Decrease margin scale factor iteratively by the value provided")
("msf-min", po::value<float>(&marginScaleFactorMin)->default_value(1.0), "Minimum value that margin is scaled by")
("weighted-loss-function", po::value<bool>(&weightedLossFunction)->default_value(false), "Weight the loss of a hypothesis by its Bleu score")
("nbest,n", po::value<size_t>(&n)->default_value(10), "Number of translations in nbest list")
("batch-size,b", po::value<size_t>(&batchSize)->default_value(1), "Size of batch that is send to optimiser for weight adjustments")
@@ -129,7 +140,9 @@ int main(int argc, char** argv) {
("use-scaled-reference", po::value<bool>(&useScaledReference)->default_value(true), "Use scaled reference length for comparing target and reference length of phrases")
("scale-by-input-length", po::value<bool>(&scaleByInputLength)->default_value(true), "Scale the BLEU score by a history of the input lengths")
("BP-factor", po::value<float>(&BPfactor)->default_value(1.0), "Increase penalty for short translations")
("slack", po::value<float>(&slack)->default_value(0), "Use slack in optimization problem")
("slack", po::value<float>(&slack)->default_value(0), "Use slack in optimizer")
("slack-step", po::value<float>(&slack_step)->default_value(0), "Increase slack from epoch to epoch by the value provided")
("slack-max", po::value<float>(&slack_max)->default_value(0), "Maximum slack used")
("max-number-oracles", po::value<size_t>(&maxNumberOracles)->default_value(1), "Set a maximum number of oracles to use per example")
("accumulate-most-violated-constraints", po::value<bool>(&accumulateMostViolatedConstraints)->default_value(false), "Accumulate most violated constraint per example")
("past-and-current-constraints", po::value<bool>(&pastAndCurrentConstraints)->default_value(false), "Accumulate most violated constraint per example and use them along all current constraints")
@@ -138,8 +151,11 @@ int main(int argc, char** argv) {
("ignore-weird-updates", po::value<bool>(&ignoreWeirdUpdates)->default_value(false), "Ignore updates that increase number of violated constraints AND increase the error")
("log-feature-values", po::value<bool>(&logFeatureValues)->default_value(false), "Take log of feature values according to the given base.")
("base-of-log", po::value<size_t>(&baseOfLog)->default_value(10), "Base for log-ing feature values")
("clipping", po::value<float>(&clipping)->default_value(0.01f), "Set a threshold to regularise updates")
("fixed-clipping", po::value<bool>(&fixedClipping)->default_value(false), "Use a fixed clipping threshold");
("clipping", po::value<float>(&clipping)->default_value(0.01), "Set a threshold to regularise updates")
("fixed-clipping", po::value<bool>(&fixedClipping)->default_value(false), "Use a fixed clipping threshold")
("decoder-settings", po::value<string>(&decoder_settings)->default_value(""), "Decoder settings for tuning runs")
("min-weight-change", po::value<float>(&min_weight_change)->default_value(0.01), "Set minimum weight change for stopping criterion")
("dev-bleu", po::value<bool>(&devBleu)->default_value(false), "Compute BLEU score of oracle translations of the whole tuning set");
po::options_description cmdline_options;
cmdline_options.add(desc);
@@ -191,7 +207,9 @@ int main(int argc, char** argv) {
}
// initialise Moses
initMoses(mosesConfigFile, verbosity);//, argc, argv);
vector<string> decoder_params;
boost::split(decoder_params, decoder_settings, boost::is_any_of("\t "));
initMoses(mosesConfigFile, verbosity, decoder_params.size(), decoder_params);
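A self-contained sketch of this splitting step (flag values illustrative). One caveat worth noting: boost::split on an empty string yields a single empty token, so an unset --decoder-settings produces one empty argument rather than none:

#include <boost/algorithm/string.hpp>
#include <iostream>
#include <string>
#include <vector>

int main() {
    std::string decoder_settings = "-distortion-limit 6 -ttable-limit 20";
    std::vector<std::string> decoder_params;
    boost::split(decoder_params, decoder_settings, boost::is_any_of("\t "));
    for (size_t i = 0; i < decoder_params.size(); ++i)
        std::cout << i << ": " << decoder_params[i] << std::endl;  // prints 4 tokens
    return 0;
}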
MosesDecoder* decoder = new MosesDecoder(referenceSentences, useScaledReference, scaleByInputLength, BPfactor, historySmoothing);
ScoreComponentCollection startWeights = decoder->getWeights();
startWeights.L1Normalise();
@@ -260,8 +278,8 @@ int main(int argc, char** argv) {
//Main loop:
ScoreComponentCollection cumulativeWeights; // collect weights per epoch to produce an average
size_t iterations = 0;
size_t iterationsThisEpoch = 0;
size_t weightChanges = 0;
size_t weightChangesThisEpoch = 0;
time_t now = time(0); // get current time
struct tm* tm = localtime(&now); // get struct filled out
@@ -270,18 +288,17 @@ int main(int argc, char** argv) {
// the result of accumulating and averaging weights over one epoch and possibly several processes
ScoreComponentCollection averageTotalWeights;
ScoreComponentCollection averageTotalWeightsCurrent;
ScoreComponentCollection averageTotalWeightsPrevious;
ScoreComponentCollection averageTotalWeightsBeforePrevious;
// TODO: scaling of feature values for probabilistic features
vector< ScoreComponentCollection> list_of_delta_h; // collect delta_h and loss for all examples of an epoch
vector< float> list_of_losses;
// print initial weights
cerr << "weights: " << decoder->getWeights() << endl;
for (size_t epoch = 0; epoch < epochs; ++epoch) {
cerr << "\nEpoch " << epoch << endl;
weightChangesThisEpoch = 0;
// Sum up weights over one epoch, final average uses weights from last epoch
iterationsThisEpoch = 0;
if (!accumulateWeights) {
cumulativeWeights.ZeroAll();
}
@@ -289,6 +306,10 @@ int main(int argc, char** argv) {
// number of weight dumps this epoch
size_t weightEpochDump = 0;
// collect all oracles for dev set
vector< vector< const Word*> > allOracles;
vector<size_t> all_ref_ids;
size_t shardPosition = 0;
vector<size_t>::const_iterator sid = shard.begin();
while (sid != shard.end()) {
@@ -330,6 +351,7 @@ int main(int argc, char** argv) {
rank);
inputLengths.push_back(decoder->getCurrentInputLength());
ref_ids.push_back(*sid);
all_ref_ids.push_back(*sid);
decoder->cleanup();
cerr << "Rank " << rank << ", model length: " << bestModel.size() << " Bleu: " << bleuScores[batchPosition][0] << endl;
@@ -350,6 +372,7 @@ int main(int argc, char** argv) {
rank);
decoder->cleanup();
oracles.push_back(oracle);
allOracles.push_back(oracle);
cerr << "Rank " << rank << ", oracle length: " << oracle.size() << " Bleu: " << bleuScores[batchPosition][oraclePos] << endl;
oracleFeatureValues.push_back(featureValues[batchPosition][oraclePos]);
@@ -380,6 +403,8 @@ int main(int argc, char** argv) {
delete fear[i];
}
cerr << "\nRank " << rank << ", sentence " << *sid << ", Bleu: " << bleuScores[batchPosition][oraclePos] << endl;
// next input sentence
++sid;
++actualBatchSize;
@@ -421,10 +446,20 @@ int main(int argc, char** argv) {
// normalise Moses weights
mosesWeights.L1Normalise();
// print weights and features values
cerr << "\nRank " << rank << ", weights (normalised): " << mosesWeights << endl;
/* cerr << "Rank " << rank << ", feature values: " << endl;
for (size_t i = 0; i < featureValues.size(); ++i) {
for (size_t j = 0; j < featureValues[i].size(); ++j) {
cerr << featureValues[i][j].Size() << ": " << featureValues[i][j] << endl;
}
}
cerr << endl;*/
// compute difference to old weights
ScoreComponentCollection weightDifference(mosesWeights);
weightDifference.MinusEquals(oldWeights);
//cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
// update history (for approximate document Bleu)
for (size_t i = 0; i < oracles.size(); ++i) {
@@ -432,82 +467,42 @@ int main(int argc, char** argv) {
}
decoder->updateHistory(oracles, inputLengths, ref_ids);
if (!devBleu) {
// clean up oracle translations after updating history
for (size_t i = 0; i < oracles.size(); ++i) {
for (size_t j = 0; j < oracles[i].size(); ++j) {
delete oracles[i][j];
}
}
// sanity check: compare margin created by old weights against new weights
float lossMinusMargin_old = 0;
float lossMinusMargin_new = 0;
for (size_t batchPosition = 0; batchPosition < actualBatchSize; ++batchPosition) {
for (size_t j = 0; j < featureValues[batchPosition].size(); ++j) {
ScoreComponentCollection featureDiff(oracleFeatureValues[batchPosition]);
featureDiff.MinusEquals(featureValues[batchPosition][j]);
// old weights
float margin = featureDiff.InnerProduct(oldWeights);
lossMinusMargin_old += (losses[batchPosition][j] - margin);
// new weights
margin = featureDiff.InnerProduct(mosesWeights);
lossMinusMargin_new += (losses[batchPosition][j] - margin);
// now collect translations of first epoch only
if (rank == 0 && epoch == 0) {
list_of_delta_h.push_back(featureDiff);
list_of_losses.push_back(losses[batchPosition][j]);
}
}
}
cerr << "\nRank " << rank << ", constraint change: " << constraintChange << endl;
cerr << "Rank " << rank << ", summed (loss - margin) with old weights: " << lossMinusMargin_old << endl;
cerr << "Rank " << rank << ", summed (loss - margin) with new weights: " << lossMinusMargin_new << endl;
bool useNewWeights = true;
if (lossMinusMargin_new > lossMinusMargin_old) {
cerr << "Rank " << rank << ", worsening: " << lossMinusMargin_new - lossMinusMargin_old << endl;
if (constraintChange < 0) {
cerr << "Rank " << rank << ", something is going wrong here.." << endl;
if (ignoreWeirdUpdates) {
useNewWeights = false;
}
}
}
if (useNewWeights) {
// set and accumulate weights
decoder->setWeights(mosesWeights);
cumulativeWeights.PlusEquals(mosesWeights);
}
else {
cerr << "Ignore new weights, keep old weights.. " << endl;
}
++iterations;
++iterationsThisEpoch;
++weightChanges;
++weightChangesThisEpoch;
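The sanity check above totals loss minus model-score margin over all constraints, once with the old weights and once with the new ones; a worsening total flags a suspicious update. A condensed sketch with plain vectors standing in for ScoreComponentCollection (toy dimensionality and numbers):

#include <iostream>
#include <vector>

// Sum of (loss - margin) over constraints, where each constraint is the
// feature difference between the oracle and one hypothesis.
float summedLossMinusMargin(const std::vector<std::vector<float> >& featureDiffs,
                            const std::vector<float>& losses,
                            const std::vector<float>& weights) {
    float total = 0;
    for (size_t i = 0; i < featureDiffs.size(); ++i) {
        float margin = 0;
        for (size_t d = 0; d < weights.size(); ++d)
            margin += featureDiffs[i][d] * weights[d];  // inner product
        total += losses[i] - margin;
    }
    return total;
}

int main() {
    std::vector<std::vector<float> > diffs(1, std::vector<float>(2));
    diffs[0][0] = 0.5f; diffs[0][1] = -0.2f;
    std::vector<float> losses(1, 0.3f);
    std::vector<float> oldWeights(2, 0.1f), newWeights(2, 0.4f);
    std::cout << summedLossMinusMargin(diffs, losses, oldWeights) << " vs "
              << summedLossMinusMargin(diffs, losses, newWeights) << std::endl;
    return 0;
}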
// mix weights?
#ifdef MPI_ENABLE
if (shardPosition % (shard.size() / mixFrequency) == 0) {
ScoreComponentCollection averageWeights;
//if (rank == 0) {
// cerr << "Rank 0, before mixing: " << mosesWeights << endl;
//}
if (rank == 0) {
cerr << "Rank 0, before mixing: " << mosesWeights << endl;
}
//VERBOSE(1, "\nRank: " << rank << " \nBefore mixing: " << mosesWeights << endl);
VERBOSE(1, "\nRank: " << rank << " \nBefore mixing: " << mosesWeights << endl);
// collect all weights in averageWeights and divide by number of processes
mpi::reduce(world, mosesWeights, averageWeights, SCCPlus(), 0);
if (rank == 0) {
averageWeights.DivideEquals(size);
//VERBOSE(1, "After mixing: " << averageWeights << endl);
//cerr << "Rank 0, after mixing: " << averageWeights << endl;
// normalise weights after averaging
averageWeights.L1Normalise();
VERBOSE(1, "After mixing (normalised): " << averageWeights << endl);
cerr << "Rank 0, after mixing (normalised): " << averageWeights << endl;
}
// broadcast average weights from process 0
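A reduced sketch of this mixing step, assuming Boost.MPI is available; a single float stands in for the full ScoreComponentCollection and the SCCPlus reduction used above:

#include <boost/mpi.hpp>
#include <functional>
#include <iostream>
namespace mpi = boost::mpi;

int main(int argc, char* argv[]) {
    mpi::environment env(argc, argv);
    mpi::communicator world;
    float weight = world.rank() + 1.0f;  // toy per-process weight
    float mixed = 0.0f;
    // collect all weights on process 0 and divide by the number of processes
    mpi::reduce(world, weight, mixed, std::plus<float>(), 0);
    if (world.rank() == 0)
        mixed /= world.size();
    // broadcast the mixed weight from process 0 back to everyone
    mpi::broadcast(world, mixed, 0);
    std::cout << "rank " << world.rank() << ": " << mixed << std::endl;
    return 0;
}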
@@ -520,14 +515,14 @@ int main(int argc, char** argv) {
// compute average weights per process over iterations
ScoreComponentCollection totalWeights(cumulativeWeights);
if (accumulateWeights)
totalWeights.DivideEquals(iterations);
totalWeights.DivideEquals(weightChanges);
else
totalWeights.DivideEquals(iterationsThisEpoch);
totalWeights.DivideEquals(weightChangesThisEpoch);
//if (rank == 0) {
//cerr << "Rank 0, cumulative weights: " << cumulativeWeights << endl;
//cerr << "Rank 0, total weights: " << totalWeights << endl;
//}
if (rank == 0) {
cerr << "Rank 0, cumulative weights: " << cumulativeWeights << endl;
cerr << "Rank 0, total weights: " << totalWeights << endl;
}
if (weightEpochDump + 1 == weightDumpFrequency){
// last weight dump in epoch
@@ -545,10 +540,10 @@ int main(int argc, char** argv) {
#endif
if (rank == 0 && !weightDumpStem.empty()) {
// average and normalise weights
// average by number of processes and normalise weights
averageTotalWeights.DivideEquals(size);
averageTotalWeights.L1Normalise();
//cerr << "Rank 0, average total weights: " << averageTotalWeights << endl;
cerr << "Rank 0, average total weights (normalised): " << averageTotalWeights << endl;
ostringstream filename;
if (epoch < 10) {
@@ -567,20 +562,14 @@ int main(int argc, char** argv) {
if (weightEpochDump + 1 == weightDumpFrequency){
// last weight dump in epoch
// compute summed error
float summedError = 0;
for (size_t i = 0; i < list_of_delta_h.size(); ++i) {
summedError += (list_of_losses[i] - list_of_delta_h[i].InnerProduct(averageTotalWeights));
}
cerr << "Rank 0, summed error after dumping weights: " << summedError << " (" << list_of_delta_h.size() << " examples)" << endl;
// compare new average weights with previous weights
averageTotalWeightsCurrent = averageTotalWeights;
ScoreComponentCollection firstDiff(averageTotalWeightsCurrent);
firstDiff.MinusEquals(averageTotalWeightsPrevious);
cerr << "Rank 0, weight changes since previous epoch: " << firstDiff << endl;
ScoreComponentCollection secondDiff(averageTotalWeightsCurrent);
secondDiff.MinusEquals(averageTotalWeightsBeforePrevious);
cerr << "Rank 0, weight changes since before previous epoch: " << secondDiff << endl;
if (!suppressConvergence) {
// check whether stopping criterion has been reached
@@ -591,7 +580,7 @@ int main(int argc, char** argv) {
FVector::const_iterator iterator1 = changes1.cbegin();
FVector::const_iterator iterator2 = changes2.cbegin();
while (iterator1 != changes1.cend()) {
if ((*iterator1).second >= 0.01 || (*iterator2).second >= 0.01) {
if (abs((*iterator1).second) >= min_weight_change || abs((*iterator2).second) >= min_weight_change) {
reached = false;
break;
}
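The new test generalises the old hard-coded 0.01 threshold: it reads the threshold from --min-weight-change and takes absolute values, so large negative swings also count as movement. A self-contained sketch of the criterion, with std::map standing in for FVector (names and numbers illustrative):

#include <cmath>
#include <iostream>
#include <map>
#include <string>

// Converged when every weight moved less than minWeightChange in each of
// the last two epochs; assumes both maps hold the same keys, as the two
// FVector difference vectors above do.
bool converged(const std::map<std::string, float>& changes1,
               const std::map<std::string, float>& changes2,
               float minWeightChange) {
    std::map<std::string, float>::const_iterator it1 = changes1.begin();
    std::map<std::string, float>::const_iterator it2 = changes2.begin();
    for (; it1 != changes1.end(); ++it1, ++it2) {
        if (std::fabs(it1->second) >= minWeightChange ||
            std::fabs(it2->second) >= minWeightChange)
            return false;
    }
    return true;
}

int main() {
    std::map<std::string, float> sincePrevious, sinceBeforePrevious;
    sincePrevious["lm"] = 0.004f;  sinceBeforePrevious["lm"] = 0.008f;
    sincePrevious["tm"] = -0.002f; sinceBeforePrevious["tm"] = 0.006f;
    std::cout << converged(sincePrevious, sinceBeforePrevious, 0.01f)
              << std::endl;  // prints 1: stopping criterion reached
    return 0;
}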
@@ -602,8 +591,7 @@ int main(int argc, char** argv) {
if (reached) {
// stop MIRA
cerr << "\nRank 0, stopping criterion has been reached after epoch " << epoch << ".. stopping MIRA." << endl << endl;
//cerr << "Rank 0, average total weights: " << averageTotalWeights << endl;
cerr << "\nRank 0, stopping criterion has been reached after epoch " << epoch << ".. stopping MIRA." << endl;
ScoreComponentCollection dummy;
ostringstream endfilename;
@@ -613,6 +601,17 @@ int main(int argc, char** argv) {
#ifdef MPI_ENABLE
MPI_Abort(MPI_COMM_WORLD, 0);
#endif
if (devBleu) {
// calculate bleu score of all oracle translations of dev set
decoder->calculateBleuOfCorpus(allOracles, all_ref_ids, epoch);
}
if (marginScaleFactorStep > 0) {
cerr << "margin scale factor: " << marginScaleFactor << endl;
}
if (slack_step > 0) {
cerr << "slack: " << slack << endl;
}
goto end;
}
}
@@ -623,8 +622,47 @@ int main(int argc, char** argv) {
}
}
//list_of_delta_h.clear();
//list_of_losses.clear();
if (devBleu) {
// calculate bleu score of all oracle translations of dev set
decoder->calculateBleuOfCorpus(allOracles, all_ref_ids, epoch);
// clean up oracle translations
for (size_t i = 0; i < allOracles.size(); ++i) {
for (size_t j = 0; j < allOracles[i].size(); ++j) {
delete allOracles[i][j];
}
}
}
// if using flexible margin scale factor, increase scaling (decrease value) for next epoch
if (marginScaleFactorStep > 0) {
if (marginScaleFactor - marginScaleFactorStep >= marginScaleFactorMin) {
if (typeid(*optimiser) == typeid(MiraOptimiser)) {
cerr << "old margin scale factor: " << marginScaleFactor << endl;
marginScaleFactor -= marginScaleFactorStep;
cerr << "new margin scale factor: " << marginScaleFactor << endl;
((MiraOptimiser*)optimiser)->setMarginScaleFactor(marginScaleFactor);
}
}
else {
cerr << "margin scale factor: " << marginScaleFactor << endl;
}
}
// if using flexible slack, increase slack for next epoch
if (slack_step > 0) {
if (slack + slack_step <= slack_max) {
if (typeid(*optimiser) == typeid(MiraOptimiser)) {
cerr << "old slack: " << slack << endl;
slack += slack_step;
cerr << "new slack: " << slack << endl;
((MiraOptimiser*)optimiser)->setSlack(slack);
}
}
else {
cerr << "slack: " << slack << endl;
}
}
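Taken together, the two blocks above implement simple per-epoch schedules: the margin scale factor decreases by --msf-step down to --msf-min, and the slack increases by --slack-step up to --slack-max. A toy loop showing just the schedules (all values illustrative):

#include <iostream>

int main() {
    float marginScaleFactor = 2.0f, msfStep = 0.25f, msfMin = 1.0f;
    float slack = 0.1f, slackStep = 0.1f, slackMax = 0.5f;
    for (int epoch = 0; epoch < 8; ++epoch) {
        // ... one training epoch would run here ...
        if (msfStep > 0 && marginScaleFactor - msfStep >= msfMin)
            marginScaleFactor -= msfStep;  // scale margins down next epoch
        if (slackStep > 0 && slack + slackStep <= slackMax)
            slack += slackStep;            // loosen constraints next epoch
        std::cout << "epoch " << epoch << ": msf=" << marginScaleFactor
                  << ", slack=" << slack << std::endl;
    }
    return 0;
}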
}
end:
@@ -633,13 +671,9 @@ int main(int argc, char** argv) {
MPI_Finalize();
#endif
//if (rank == 0) {
// cerr << "Rank 0, average total weights: " << averageTotalWeights << endl;
//}
now = time(0); // get current time
tm = localtime(&now); // get struct filled out
cerr << "End date/time: " << tm->tm_mon+1 << "/" << tm->tm_mday << "/" << tm->tm_year + 1900
cerr << "\nEnd date/time: " << tm->tm_mon+1 << "/" << tm->tm_mday << "/" << tm->tm_year + 1900
<< ", " << tm->tm_hour << ":" << tm->tm_min << ":" << tm->tm_sec << endl;
delete decoder;

View File

@@ -63,6 +63,7 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
for (size_t i = 0; i < featureValues.size(); ++i) {
size_t sentenceId = sentenceIds[i];
if (m_oracles[sentenceId].size() > 1)
cerr << "Available oracles for source sentence " << sentenceId << ": " << m_oracles[sentenceId].size() << endl;
for (size_t j = 0; j < featureValues[i].size(); ++j) {
// check if optimisation criterion is violated for one hypothesis and the oracle
@@ -90,6 +91,14 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
addConstraint = false;
}
/*if (modelScoreDiff < loss) {
// constraint violated
cerr << modelScoreDiff << " < " << loss << endl;
}
else {
cerr << modelScoreDiff << " >= " << loss << endl;
}*/
if (addConstraint) {
float lossMarginDistance = loss - modelScoreDiff;
@@ -134,7 +143,7 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
m_featureValueDiffs.push_back(maxViolationfeatureValueDiff);
m_lossMarginDistances.push_back(maxViolationLossMarginDistance);
cerr << "Number of constraints passed to optimiser: " << m_featureValueDiffs.size() << endl;
// cerr << "Number of constraints passed to optimiser: " << m_featureValueDiffs.size() << endl;
if (m_slack != 0) {
alphas = Hildreth::optimise(m_featureValueDiffs, m_lossMarginDistances, m_slack);
}
@@ -176,10 +185,28 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
m_lossMarginDistances.push_back(maxViolationLossMarginDistance);
}
//cerr << "Number of violated constraints before optimisation: " << violatedConstraintsBefore << endl;
cerr << "Number of violated constraints before optimisation: " << violatedConstraintsBefore << endl;
if (featureValueDiffs.size() != 30) {
cerr << "Number of constraints passed to optimiser: " << featureValueDiffs.size() << endl;
}
if (m_slack != 0) {
/*cerr << "Feature value diffs (A): " << endl;
for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
cerr << featureValueDiffs[k] << endl;
}
cerr << endl << "Loss - margin (b): " << endl;
for (size_t i = 0; i < lossMarginDistances.size(); ++i) {
cerr << lossMarginDistances[i] << endl;
}
cerr << endl;
cerr << "Slack: " << m_slack << endl;*/
alphas = Hildreth::optimise(featureValueDiffs, lossMarginDistances, m_slack);
/*cerr << "Alphas: " << endl;
for (size_t i = 0; i < alphas.size(); ++i) {
cerr << alphas[i] << endl;
}
cerr << endl << endl;*/
}
else {
alphas = Hildreth::optimise(featureValueDiffs, lossMarginDistances);
@@ -189,23 +216,29 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
// * w' = w' + delta * Dh_ij ---> w' = w' + delta * (h(e*) - h(e_ij))
for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
// compute update
float update = alphas[k];
float alpha = alphas[k];
if (m_fixedClipping) {
if (update > m_c) {
update = m_c;
if (alpha > m_c) {
alpha = m_c;
}
else if (update < -1 * m_c) {
update = -1 * m_c;
else if (alpha < -1 * m_c) {
alpha = -1 * m_c;
}
}
featureValueDiffs[k].MultiplyEquals(update);
cerr << "alpha: " << update << endl;
featureValueDiffs[k].MultiplyEquals(alpha);
//cerr << "alpha: " << alpha << endl;
// apply update to weight vector
currWeights.PlusEquals(featureValueDiffs[k]);
}
/*cerr << "Updates to weight vector: " << endl;
for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
cerr << featureValueDiffs[k] << endl;
}
cerr << endl << endl;*/
// sanity check: how many constraints violated after optimisation?
size_t violatedConstraintsAfter = 0;
for (size_t i = 0; i < featureValues.size(); ++i) {
@@ -216,15 +249,16 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
float loss = losses[i][j] * m_marginScaleFactor;
if (modelScoreDiff < loss) {
++violatedConstraintsAfter;
//cerr << modelScoreDiff << " < " << loss << endl;
}
else {
//cerr << modelScoreDiff << " >= " << loss << endl;
}
}
}
//cerr << "Number of violated constraints after optimisation: " << violatedConstraintsAfter << endl;
if (violatedConstraintsAfter > violatedConstraintsBefore) {
cerr << "Increase: " << violatedConstraintsAfter - violatedConstraintsBefore << endl << endl;
}
int constraintChange = violatedConstraintsBefore - violatedConstraintsAfter;
cerr << "Constraint change: " << constraintChange << endl;
return violatedConstraintsBefore - violatedConstraintsAfter;
}
else {
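The fixed-clipping branch earlier in this file bounds each Hildreth step size to [-m_c, m_c] before it scales the corresponding feature-difference vector. An equivalent one-liner sketch (threshold and alphas illustrative):

#include <algorithm>
#include <iostream>

// Equivalent to the if/else clipping above: clamp alpha into [-c, c].
float clipAlpha(float alpha, float c) {
    return std::max(-c, std::min(c, alpha));
}

int main() {
    float c = 0.01f;  // --clipping with --fixed-clipping true
    float alphas[] = { 0.2f, -0.03f, 0.005f };
    for (int k = 0; k < 3; ++k)
        std::cout << clipAlpha(alphas[k], c) << std::endl;  // 0.01 -0.01 0.005
    return 0;
}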

View File

@@ -105,6 +105,14 @@ namespace Mira {
m_oracleIndices= oracleIndices;
}
void setSlack(float slack) {
m_slack = slack;
}
void setMarginScaleFactor(float msf) {
m_marginScaleFactor = msf;
}
private:
// number of hypotheses used for each nbest list (number of hope, fear, best model translations)
size_t m_n;
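These setters are what the per-epoch schedules in the main loop call through a typeid-guarded downcast. A minimal sketch of that pattern with toy class names (the real classes live in the optimiser hierarchy above):

#include <iostream>
#include <typeinfo>

struct Optimiser { virtual ~Optimiser() {} };
struct MiraOptimiser : Optimiser {
    float m_slack;
    MiraOptimiser() : m_slack(0) {}
    void setSlack(float slack) { m_slack = slack; }
};

int main() {
    Optimiser* optimiser = new MiraOptimiser();
    if (typeid(*optimiser) == typeid(MiraOptimiser))  // same guard as the trainer
        ((MiraOptimiser*)optimiser)->setSlack(0.1f);
    std::cout << ((MiraOptimiser*)optimiser)->m_slack << std::endl;  // 0.1
    delete optimiser;
    return 0;
}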

View File

@@ -238,6 +238,51 @@ void BleuScoreFeature::GetNgramMatchCounts(Phrase& phrase,
}
}
void BleuScoreFeature::GetClippedNgramMatchesAndCounts(Phrase& phrase,
const NGrams& ref_ngram_counts,
std::vector< size_t >& ret_counts,
std::vector< size_t >& ret_matches,
size_t skip_first) const
{
std::map< Phrase, size_t >::const_iterator ref_ngram_counts_iter;
size_t ngram_start_idx, ngram_end_idx;
std::map<size_t, std::map<Phrase, size_t> > ngram_matches;
for (size_t end_idx = skip_first; end_idx < phrase.GetSize(); end_idx++) {
for (size_t order = 0; order < BleuScoreState::bleu_order; order++) {
if (order > end_idx) break;
ngram_end_idx = end_idx;
ngram_start_idx = end_idx - order;
Phrase ngram = phrase.GetSubString(WordsRange(ngram_start_idx, ngram_end_idx));
string ngramString = ngram.ToString();
ret_counts[order]++;
ref_ngram_counts_iter = ref_ngram_counts.find(ngram);
if (ref_ngram_counts_iter != ref_ngram_counts.end()) {
ngram_matches[order][ngram]++;
}
}
}
// clip ngram matches
for (size_t order = 0; order < BleuScoreState::bleu_order; order++) {
std::map<Phrase, size_t>::const_iterator iter;
// iterate over ngram counts for every ngram order
for (iter=ngram_matches[order].begin(); iter != ngram_matches[order].end(); ++iter) {
ref_ngram_counts_iter = ref_ngram_counts.find(iter->first);
if (iter->second > ref_ngram_counts_iter->second) {
ret_matches[order] += ref_ngram_counts_iter->second;
}
else {
ret_matches[order] += iter->second;
}
}
}
}
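The clipping above credits each distinct n-gram at most as often as it appears in the reference, which is what keeps degenerate repetitions from inflating precision. A self-contained sketch of the same idea for unigrams over plain strings (toy tokens, not the Phrase/NGrams types):

#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Clipped unigram matches: each word earns at most its reference count.
size_t clippedMatches(const std::vector<std::string>& hyp,
                      const std::vector<std::string>& ref) {
    std::map<std::string, size_t> refCounts, hypMatches;
    for (size_t i = 0; i < ref.size(); ++i) refCounts[ref[i]]++;
    for (size_t i = 0; i < hyp.size(); ++i)
        if (refCounts.count(hyp[i])) hypMatches[hyp[i]]++;
    size_t matches = 0;
    std::map<std::string, size_t>::const_iterator it;
    for (it = hypMatches.begin(); it != hypMatches.end(); ++it)
        matches += std::min(it->second, refCounts[it->first]);
    return matches;
}

int main() {
    std::vector<std::string> hyp, ref;
    hyp.push_back("the"); hyp.push_back("the"); hyp.push_back("the");
    ref.push_back("the"); ref.push_back("cat");
    std::cout << clippedMatches(hyp, ref) << std::endl;  // 1, not 3
    return 0;
}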
/*
* Given a previous state, compute Bleu score for the updated state with an additional target
* phrase translated.
@@ -371,6 +416,94 @@ float BleuScoreFeature::CalculateBleu(BleuScoreState* state) const {
return precision;
}
vector<float> BleuScoreFeature::CalculateBleuOfCorpus(const vector< vector< const Word* > >& oracles, const vector<size_t>& ref_ids) {
// get ngram matches and counts for all oracle sentences and their references
vector<size_t> sumOfClippedNgramMatches(BleuScoreState::bleu_order);
vector<size_t> sumOfNgramCounts(BleuScoreState::bleu_order);
size_t ref_length = 0;
size_t target_length = 0;
for (size_t batchPosition = 0; batchPosition < oracles.size(); ++batchPosition){
Phrase phrase(Output, oracles[batchPosition]);
size_t ref_id = ref_ids[batchPosition];
size_t cur_ref_length = m_refs[ref_id].first;
NGrams cur_ref_ngrams = m_refs[ref_id].second;
ref_length += cur_ref_length;
target_length += oracles[batchPosition].size();
std::vector< size_t > ngram_counts(BleuScoreState::bleu_order);
std::vector< size_t > clipped_ngram_matches(BleuScoreState::bleu_order);
GetClippedNgramMatchesAndCounts(phrase, cur_ref_ngrams, ngram_counts, clipped_ngram_matches, 0);
// add clipped ngram matches and ngram counts to corpus sums
for (size_t i = 0; i < BleuScoreState::bleu_order; i++) {
sumOfClippedNgramMatches[i] += clipped_ngram_matches[i];
sumOfNgramCounts[i] += ngram_counts[i];
}
}
if (!sumOfNgramCounts[0]) {
vector<float> empty(0);
return empty;
}
if (!sumOfClippedNgramMatches[0]) {
vector<float> empty(0);
return empty; // if we have no unigram matches, score should be 0
}
// calculate bleu score
float precision = 1.0;
float smoothed_count, smoothed_matches;
vector<float> bleu;
// Calculate geometric mean of modified ngram precisions
// BLEU = BP * exp(SUM_1_4 1/4 * log p_n)
// = BP * 4th root(PRODUCT_1_4 p_n)
for (size_t i = 0; i < BleuScoreState::bleu_order; i++) {
if (sumOfNgramCounts[i]) {
smoothed_matches = sumOfClippedNgramMatches[i];
smoothed_count = sumOfNgramCounts[i];
if (i > 0) {
// smoothing for all n > 1
smoothed_matches += 1;
smoothed_count += 1;
}
precision *= smoothed_matches / smoothed_count;
bleu.push_back(smoothed_matches / smoothed_count);
}
else {
cerr << "no counts for order " << i+1 << endl;
}
}
// take geometric mean
precision = pow(precision, (float)1/4);
// Apply brevity penalty if applicable.
// BP = 1 if c > r
// BP = e^(1 - r/c) if c <= r
// where
// c: length of the candidate translation
// r: effective reference length (sum of best match lengths for each candidate sentence)
float BP;
if (target_length < ref_length) {
precision *= exp(1 - (1.0*ref_length/target_length));
BP = exp(1 - (1.0*ref_length/target_length));
}
else {
BP = 1.0;
}
bleu.push_back(precision);
bleu.push_back(BP);
bleu.push_back(1.0*target_length/ref_length);
bleu.push_back(target_length);
bleu.push_back(ref_length);
return bleu;
}
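Given the summed clipped matches and counts, the function above computes standard corpus BLEU with add-one smoothing on orders 2-4 (which keeps a single missing higher-order match from zeroing the score). A condensed sketch of the same arithmetic, assuming at least one unigram match (the feature returns an empty vector otherwise):

#include <cmath>
#include <iostream>
#include <vector>

// Corpus BLEU from per-order clipped matches/counts: geometric mean of
// smoothed precisions times the brevity penalty.
float corpusBleu(const std::vector<size_t>& matches,
                 const std::vector<size_t>& counts,
                 size_t hypLen, size_t refLen) {
    float logSum = 0;
    for (size_t i = 0; i < matches.size(); ++i) {
        float m = matches[i] + (i > 0 ? 1.0f : 0.0f);  // +1 smoothing for n > 1
        float c = counts[i]  + (i > 0 ? 1.0f : 0.0f);
        logSum += std::log(m / c);
    }
    float precision = std::exp(logSum / matches.size());
    float bp = (hypLen < refLen) ? std::exp(1.0f - (float)refLen / hypLen) : 1.0f;
    return bp * precision;
}

int main() {
    size_t m[] = { 8, 5, 3, 2 }, c[] = { 10, 9, 8, 7 };
    std::vector<size_t> matches(m, m + 4), counts(c, c + 4);
    std::cout << corpusBleu(matches, counts, 10, 12) << std::endl;  // ~0.43
    return 0;
}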
const FFState* BleuScoreFeature::EmptyHypothesisState(const InputType& input) const
{
return new BleuScoreState();

View File

@@ -72,11 +72,17 @@ public:
std::vector< size_t >&,
std::vector< size_t >&,
size_t skip = 0) const;
void GetClippedNgramMatchesAndCounts(Phrase&,
const NGrams&,
std::vector< size_t >&,
std::vector< size_t >&,
size_t skip = 0) const;
FFState* Evaluate( const Hypothesis& cur_hypo,
const FFState* prev_state,
ScoreComponentCollection* accumulator) const;
float CalculateBleu(BleuScoreState*) const;
std::vector<float> CalculateBleuOfCorpus(const std::vector< std::vector< const Word* > >& hypos, const std::vector<size_t>& ref_ids);
const FFState* EmptyHypothesisState(const InputType&) const;
private: