mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-26 13:23:25 +03:00
compute bleu on oracles translations of dev set, introduce params --min-weight-change, --msf-step, --msf-min, --slack-step, --slack-max, --decoder-settings
git-svn-id: http://svn.statmt.org/repository/mira@3804 cc96ff50-19ce-11e0-b349-13d7f0bd23df
This commit is contained in:
parent
c4acd339fd
commit
5312e8fc22
@ -43,7 +43,7 @@ namespace Mira {
|
||||
return c;
|
||||
}
|
||||
|
||||
void initMoses(const string& inifile, int debuglevel, int argc, char** argv) {
|
||||
void initMoses(const string& inifile, int debuglevel, int argc, vector<string> decoder_params) {
|
||||
static int BASE_ARGC = 5;
|
||||
Parameter* params = new Parameter();
|
||||
char ** mosesargv = new char*[BASE_ARGC + argc];
|
||||
@ -56,8 +56,10 @@ namespace Mira {
|
||||
mosesargv[4] = strToChar("-mbr"); //so we can do nbest
|
||||
|
||||
for (int i = 0; i < argc; ++i) {
|
||||
mosesargv[BASE_ARGC + i] = argv[i];
|
||||
char *cstr = &(decoder_params[i])[0];
|
||||
mosesargv[BASE_ARGC + i] = cstr;
|
||||
}
|
||||
|
||||
params->LoadParam(BASE_ARGC + argc,mosesargv);
|
||||
StaticData::LoadDataStatic(params);
|
||||
for (int i = 0; i < BASE_ARGC; ++i) {
|
||||
@ -68,21 +70,16 @@ namespace Mira {
|
||||
|
||||
MosesDecoder::MosesDecoder(const vector<vector<string> >& refs, bool useScaledReference, bool scaleByInputLength, float BPfactor, float historySmoothing)
|
||||
: m_manager(NULL) {
|
||||
// force initialisation of the phrase dictionary
|
||||
// force initialisation of the phrase dictionary (TODO: what for?)
|
||||
const StaticData &staticData = StaticData::Instance();
|
||||
|
||||
// is this needed?
|
||||
//m_sentence = new Sentence(Input);
|
||||
//stringstream in("Initialising decoder..\n");
|
||||
//const std::vector<FactorType> &inputFactorOrder = staticData.GetInputFactorOrder();
|
||||
//m_sentence->Read(in,inputFactorOrder);
|
||||
m_sentence = new Sentence(Input);
|
||||
stringstream in("Initialising decoder..\n");
|
||||
const std::vector<FactorType> &inputFactorOrder = staticData.GetInputFactorOrder();
|
||||
m_sentence->Read(in,inputFactorOrder);
|
||||
|
||||
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
|
||||
|
||||
// is this needed?
|
||||
//(TranslationSystem::DEFAULT);
|
||||
//m_manager = new Manager(*m_sentence, staticData.GetSearchAlgorithm(), &system);
|
||||
//m_manager->ProcessSentence();
|
||||
m_manager = new Manager(*m_sentence, staticData.GetSearchAlgorithm(), &system);
|
||||
m_manager->ProcessSentence();
|
||||
|
||||
// Add the bleu feature
|
||||
m_bleuScoreFeature = new BleuScoreFeature(useScaledReference, scaleByInputLength, BPfactor, historySmoothing);
|
||||
@ -107,14 +104,13 @@ namespace Mira {
|
||||
bool ignoreUWeight,
|
||||
size_t rank)
|
||||
{
|
||||
StaticData &staticData = StaticData::InstanceNonConst();
|
||||
StaticData &staticData = StaticData::InstanceNonConst();
|
||||
|
||||
m_sentence = new Sentence(Input);
|
||||
m_sentence = new Sentence(Input);
|
||||
stringstream in(source + "\n");
|
||||
const std::vector<FactorType> &inputFactorOrder = staticData.GetInputFactorOrder();
|
||||
m_sentence->Read(in,inputFactorOrder);
|
||||
const TranslationSystem& system = staticData.GetTranslationSystem
|
||||
(TranslationSystem::DEFAULT);
|
||||
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
|
||||
|
||||
// set the weight for the bleu feature
|
||||
ostringstream bleuWeightStr;
|
||||
@ -212,5 +208,19 @@ namespace Mira {
|
||||
void MosesDecoder::updateHistory(const vector< vector< const Word*> >& words, vector<size_t>& sourceLengths, vector<size_t>& ref_ids) {
|
||||
m_bleuScoreFeature->UpdateHistory(words, sourceLengths, ref_ids);
|
||||
}
|
||||
|
||||
void MosesDecoder::calculateBleuOfCorpus(const vector< vector< const Word*> >& words, vector<size_t>& ref_ids, size_t epoch) {
|
||||
vector<float> bleu = m_bleuScoreFeature->CalculateBleuOfCorpus(words, ref_ids);
|
||||
cerr << "\nBleu after epoch " << epoch << ": ";
|
||||
if (bleu.size() > 0) {
|
||||
cerr << "\nBLEU: " << bleu[4]*100 << ", "
|
||||
<< bleu[3]*100 << "/" << bleu[2]*100 << "/" << bleu[1]*100 << "/" << bleu[0]*100 << " "
|
||||
<< "(BP=" << bleu[5] << ", " << "ratio=" << bleu[6] << ", "
|
||||
<< "hyp_len=" << bleu[7] << ", ref_len=" << bleu[8] << ")" << endl;
|
||||
}
|
||||
else {
|
||||
cerr << "BLEU: 0" << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,7 +42,7 @@ namespace Mira {
|
||||
* Initialise moses (including StaticData) using the given ini file and debuglevel, passing through any
|
||||
* other command line arguments.
|
||||
**/
|
||||
void initMoses(const std::string& inifile, int debuglevel, int argc=0, char** argv=NULL );
|
||||
void initMoses(const std::string& inifile, int debuglevel, int argc, std::vector<std::string> decoder_params);
|
||||
|
||||
|
||||
/**
|
||||
@ -67,6 +67,7 @@ class MosesDecoder {
|
||||
size_t getCurrentInputLength();
|
||||
void updateHistory(const std::vector<const Moses::Word*>& words);
|
||||
void updateHistory(const std::vector< std::vector< const Moses::Word*> >& words, std::vector<size_t>& sourceLengths, std::vector<size_t>& ref_ids);
|
||||
void calculateBleuOfCorpus(const std::vector< std::vector< const Moses::Word*> >& words, std::vector<size_t>& ref_ids, size_t epoch);
|
||||
Moses::ScoreComponentCollection getWeights();
|
||||
void setWeights(const Moses::ScoreComponentCollection& weights);
|
||||
void cleanup();
|
||||
|
312
mira/Main.cpp
312
mira/Main.cpp
@ -24,6 +24,8 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
#include <vector>
|
||||
|
||||
#include <boost/program_options.hpp>
|
||||
#include <boost/algorithm/string.hpp>
|
||||
|
||||
#ifdef MPI_ENABLE
|
||||
#include <boost/mpi.hpp>
|
||||
namespace mpi = boost::mpi;
|
||||
@ -83,6 +85,8 @@ int main(int argc, char** argv) {
|
||||
size_t weightDumpFrequency;
|
||||
string weightDumpStem;
|
||||
float marginScaleFactor;
|
||||
float marginScaleFactorStep;
|
||||
float marginScaleFactorMin;
|
||||
bool weightedLossFunction;
|
||||
size_t n;
|
||||
size_t batchSize;
|
||||
@ -94,6 +98,8 @@ int main(int argc, char** argv) {
|
||||
bool scaleByInputLength;
|
||||
float BPfactor;
|
||||
float slack;
|
||||
float slack_step;
|
||||
float slack_max;
|
||||
size_t maxNumberOracles;
|
||||
bool accumulateMostViolatedConstraints;
|
||||
bool pastAndCurrentConstraints;
|
||||
@ -104,21 +110,26 @@ int main(int argc, char** argv) {
|
||||
size_t baseOfLog;
|
||||
float clipping;
|
||||
bool fixedClipping;
|
||||
string decoder_settings;
|
||||
float min_weight_change;
|
||||
bool devBleu;
|
||||
po::options_description desc("Allowed options");
|
||||
desc.add_options()
|
||||
("help",po::value( &help )->zero_tokens()->default_value(false), "Print this help message and exit")
|
||||
("config,f",po::value<string>(&mosesConfigFile),"Moses ini file")
|
||||
("verbosity,v", po::value<int>(&verbosity)->default_value(0), "Verbosity level")
|
||||
("input-file,i",po::value<string>(&inputFile),"Input file containing tokenised source")
|
||||
("reference-files,r", po::value<vector<string> >(&referenceFiles), "Reference translation files for training")
|
||||
("epochs,e", po::value<size_t>(&epochs)->default_value(1), "Number of epochs")
|
||||
("learner,l", po::value<string>(&learner)->default_value("mira"), "Learning algorithm")
|
||||
("mix-frequency", po::value<size_t>(&mixFrequency)->default_value(1), "How often per epoch to mix weights, when using mpi")
|
||||
("weight-dump-stem", po::value<string>(&weightDumpStem)->default_value("weights"), "Stem of filename to use for dumping weights")
|
||||
("weight-dump-frequency", po::value<size_t>(&weightDumpFrequency)->default_value(1), "How often per epoch to dump weights")
|
||||
("shuffle", po::value<bool>(&shuffle)->default_value(false), "Shuffle input sentences before processing")
|
||||
("help",po::value( &help )->zero_tokens()->default_value(false), "Print this help message and exit")
|
||||
("config,f",po::value<string>(&mosesConfigFile),"Moses ini file")
|
||||
("verbosity,v", po::value<int>(&verbosity)->default_value(0), "Verbosity level")
|
||||
("input-file,i",po::value<string>(&inputFile),"Input file containing tokenised source")
|
||||
("reference-files,r", po::value<vector<string> >(&referenceFiles), "Reference translation files for training")
|
||||
("epochs,e", po::value<size_t>(&epochs)->default_value(1), "Number of epochs")
|
||||
("learner,l", po::value<string>(&learner)->default_value("mira"), "Learning algorithm")
|
||||
("mix-frequency", po::value<size_t>(&mixFrequency)->default_value(1), "How often per epoch to mix weights, when using mpi")
|
||||
("weight-dump-stem", po::value<string>(&weightDumpStem)->default_value("weights"), "Stem of filename to use for dumping weights")
|
||||
("weight-dump-frequency", po::value<size_t>(&weightDumpFrequency)->default_value(1), "How often per epoch to dump weights")
|
||||
("shuffle", po::value<bool>(&shuffle)->default_value(false), "Shuffle input sentences before processing")
|
||||
("hildreth", po::value<bool>(&hildreth)->default_value(true), "Use Hildreth's optimisation algorithm")
|
||||
("margin-scale-factor,m", po::value<float>(&marginScaleFactor)->default_value(1.0), "Margin scale factor, regularises the update by scaling the enforced margin")
|
||||
("msf", po::value<float>(&marginScaleFactor)->default_value(1.0), "Margin scale factor, regularises the update by scaling the enforced margin")
|
||||
("msf-step", po::value<float>(&marginScaleFactorStep)->default_value(0), "Decrease margin scale factor iteratively by the value provided")
|
||||
("msf-min", po::value<float>(&marginScaleFactorMin)->default_value(1.0), "Minimum value that margin is scaled by")
|
||||
("weighted-loss-function", po::value<bool>(&weightedLossFunction)->default_value(false), "Weight the loss of a hypothesis by its Bleu score")
|
||||
("nbest,n", po::value<size_t>(&n)->default_value(10), "Number of translations in nbest list")
|
||||
("batch-size,b", po::value<size_t>(&batchSize)->default_value(1), "Size of batch that is send to optimiser for weight adjustments")
|
||||
@ -129,7 +140,9 @@ int main(int argc, char** argv) {
|
||||
("use-scaled-reference", po::value<bool>(&useScaledReference)->default_value(true), "Use scaled reference length for comparing target and reference length of phrases")
|
||||
("scale-by-input-length", po::value<bool>(&scaleByInputLength)->default_value(true), "Scale the BLEU score by a history of the input lengths")
|
||||
("BP-factor", po::value<float>(&BPfactor)->default_value(1.0), "Increase penalty for short translations")
|
||||
("slack", po::value<float>(&slack)->default_value(0), "Use slack in optimization problem")
|
||||
("slack", po::value<float>(&slack)->default_value(0), "Use slack in optimizer")
|
||||
("slack-step", po::value<float>(&slack_step)->default_value(0), "Increase slack from epoch to epoch by the value provided")
|
||||
("slack-max", po::value<float>(&slack_max)->default_value(0), "Maximum slack used")
|
||||
("max-number-oracles", po::value<size_t>(&maxNumberOracles)->default_value(1), "Set a maximum number of oracles to use per example")
|
||||
("accumulate-most-violated-constraints", po::value<bool>(&accumulateMostViolatedConstraints)->default_value(false), "Accumulate most violated constraint per example")
|
||||
("past-and-current-constraints", po::value<bool>(&pastAndCurrentConstraints)->default_value(false), "Accumulate most violated constraint per example and use them along all current constraints")
|
||||
@ -138,8 +151,11 @@ int main(int argc, char** argv) {
|
||||
("ignore-weird-updates", po::value<bool>(&ignoreWeirdUpdates)->default_value(false), "Ignore updates that increase number of violated constraints AND increase the error")
|
||||
("log-feature-values", po::value<bool>(&logFeatureValues)->default_value(false), "Take log of feature values according to the given base.")
|
||||
("base-of-log", po::value<size_t>(&baseOfLog)->default_value(10), "Base for log-ing feature values")
|
||||
("clipping", po::value<float>(&clipping)->default_value(0.01f), "Set a threshold to regularise updates")
|
||||
("fixed-clipping", po::value<bool>(&fixedClipping)->default_value(false), "Use a fixed clipping threshold");
|
||||
("clipping", po::value<float>(&clipping)->default_value(0.01), "Set a threshold to regularise updates")
|
||||
("fixed-clipping", po::value<bool>(&fixedClipping)->default_value(false), "Use a fixed clipping threshold")
|
||||
("decoder-settings", po::value<string>(&decoder_settings)->default_value(""), "Decoder settings for tuning runs")
|
||||
("min-weight-change", po::value<float>(&min_weight_change)->default_value(0.01), "Set minimum weight change for stopping criterion")
|
||||
("dev-bleu", po::value<bool>(&devBleu)->default_value(false), "Compute BLEU score of oracle translations of the whole tuning set");
|
||||
|
||||
po::options_description cmdline_options;
|
||||
cmdline_options.add(desc);
|
||||
@ -191,7 +207,9 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
|
||||
// initialise Moses
|
||||
initMoses(mosesConfigFile, verbosity);//, argc, argv);
|
||||
vector<string> decoder_params;
|
||||
boost::split(decoder_params, decoder_settings, boost::is_any_of("\t "));
|
||||
initMoses(mosesConfigFile, verbosity, decoder_params.size(), decoder_params);
|
||||
MosesDecoder* decoder = new MosesDecoder(referenceSentences, useScaledReference, scaleByInputLength, BPfactor, historySmoothing);
|
||||
ScoreComponentCollection startWeights = decoder->getWeights();
|
||||
startWeights.L1Normalise();
|
||||
@ -260,8 +278,8 @@ int main(int argc, char** argv) {
|
||||
|
||||
//Main loop:
|
||||
ScoreComponentCollection cumulativeWeights; // collect weights per epoch to produce an average
|
||||
size_t iterations = 0;
|
||||
size_t iterationsThisEpoch = 0;
|
||||
size_t weightChanges = 0;
|
||||
size_t weightChangesThisEpoch = 0;
|
||||
|
||||
time_t now = time(0); // get current time
|
||||
struct tm* tm = localtime(&now); // get struct filled out
|
||||
@ -270,18 +288,17 @@ int main(int argc, char** argv) {
|
||||
|
||||
// the result of accumulating and averaging weights over one epoch and possibly several processes
|
||||
ScoreComponentCollection averageTotalWeights;
|
||||
|
||||
ScoreComponentCollection averageTotalWeightsCurrent;
|
||||
ScoreComponentCollection averageTotalWeightsPrevious;
|
||||
ScoreComponentCollection averageTotalWeightsBeforePrevious;
|
||||
|
||||
// TODO: scaling of feature values for probabilistic features
|
||||
vector< ScoreComponentCollection> list_of_delta_h; // collect delta_h and loss for all examples of an epoch
|
||||
vector< float> list_of_losses;
|
||||
// print initial weights
|
||||
cerr << "weights: " << decoder->getWeights() << endl;
|
||||
|
||||
for (size_t epoch = 0; epoch < epochs; ++epoch) {
|
||||
cerr << "\nEpoch " << epoch << endl;
|
||||
weightChangesThisEpoch = 0;
|
||||
// Sum up weights over one epoch, final average uses weights from last epoch
|
||||
iterationsThisEpoch = 0;
|
||||
if (!accumulateWeights) {
|
||||
cumulativeWeights.ZeroAll();
|
||||
}
|
||||
@ -289,6 +306,10 @@ int main(int argc, char** argv) {
|
||||
// number of weight dumps this epoch
|
||||
size_t weightEpochDump = 0;
|
||||
|
||||
// collect all oracles for dev set
|
||||
vector< vector< const Word*> > allOracles;
|
||||
vector<size_t> all_ref_ids;
|
||||
|
||||
size_t shardPosition = 0;
|
||||
vector<size_t>::const_iterator sid = shard.begin();
|
||||
while (sid != shard.end()) {
|
||||
@ -330,6 +351,7 @@ int main(int argc, char** argv) {
|
||||
rank);
|
||||
inputLengths.push_back(decoder->getCurrentInputLength());
|
||||
ref_ids.push_back(*sid);
|
||||
all_ref_ids.push_back(*sid);
|
||||
decoder->cleanup();
|
||||
cerr << "Rank " << rank << ", model length: " << bestModel.size() << " Bleu: " << bleuScores[batchPosition][0] << endl;
|
||||
|
||||
@ -338,8 +360,8 @@ int main(int argc, char** argv) {
|
||||
size_t oraclePos = featureValues[batchPosition].size();
|
||||
oraclePositions.push_back(oraclePos);
|
||||
vector<const Word*> oracle = decoder->getNBest(input,
|
||||
*sid,
|
||||
n,
|
||||
*sid,
|
||||
n,
|
||||
1.0,
|
||||
1.0,
|
||||
featureValues[batchPosition],
|
||||
@ -350,6 +372,7 @@ int main(int argc, char** argv) {
|
||||
rank);
|
||||
decoder->cleanup();
|
||||
oracles.push_back(oracle);
|
||||
allOracles.push_back(oracle);
|
||||
cerr << "Rank " << rank << ", oracle length: " << oracle.size() << " Bleu: " << bleuScores[batchPosition][oraclePos] << endl;
|
||||
|
||||
oracleFeatureValues.push_back(featureValues[batchPosition][oraclePos]);
|
||||
@ -380,51 +403,63 @@ int main(int argc, char** argv) {
|
||||
delete fear[i];
|
||||
}
|
||||
|
||||
cerr << "\nRank " << rank << ", sentence " << *sid << ", Bleu: " << bleuScores[batchPosition][oraclePos] << endl;
|
||||
|
||||
// next input sentence
|
||||
++sid;
|
||||
++actualBatchSize;
|
||||
++shardPosition;
|
||||
}
|
||||
|
||||
// Set loss for each sentence as BLEU(oracle) - BLEU(hypothesis)
|
||||
vector< vector<float> > losses(actualBatchSize);
|
||||
for (size_t batchPosition = 0; batchPosition < actualBatchSize; ++batchPosition) {
|
||||
for (size_t j = 0; j < bleuScores[batchPosition].size(); ++j) {
|
||||
losses[batchPosition].push_back(oracleBleuScores[batchPosition] - bleuScores[batchPosition][j]);
|
||||
}
|
||||
}
|
||||
// Set loss for each sentence as BLEU(oracle) - BLEU(hypothesis)
|
||||
vector< vector<float> > losses(actualBatchSize);
|
||||
for (size_t batchPosition = 0; batchPosition < actualBatchSize; ++batchPosition) {
|
||||
for (size_t j = 0; j < bleuScores[batchPosition].size(); ++j) {
|
||||
losses[batchPosition].push_back(oracleBleuScores[batchPosition] - bleuScores[batchPosition][j]);
|
||||
}
|
||||
}
|
||||
|
||||
// get weight vector and set weight for bleu feature to 0
|
||||
ScoreComponentCollection mosesWeights = decoder->getWeights();
|
||||
const vector<const ScoreProducer*> featureFunctions = StaticData::Instance().GetTranslationSystem (TranslationSystem::DEFAULT).GetFeatureFunctions();
|
||||
mosesWeights.Assign(featureFunctions.back(), 0);
|
||||
// get weight vector and set weight for bleu feature to 0
|
||||
ScoreComponentCollection mosesWeights = decoder->getWeights();
|
||||
const vector<const ScoreProducer*> featureFunctions = StaticData::Instance().GetTranslationSystem (TranslationSystem::DEFAULT).GetFeatureFunctions();
|
||||
mosesWeights.Assign(featureFunctions.back(), 0);
|
||||
|
||||
if (!hildreth && typeid(*optimiser) == typeid(MiraOptimiser)) {
|
||||
((MiraOptimiser*)optimiser)->setOracleIndices(oraclePositions);
|
||||
}
|
||||
if (!hildreth && typeid(*optimiser) == typeid(MiraOptimiser)) {
|
||||
((MiraOptimiser*)optimiser)->setOracleIndices(oraclePositions);
|
||||
}
|
||||
|
||||
if (logFeatureValues) {
|
||||
for (size_t i = 0; i < featureValues.size(); ++i) {
|
||||
for (size_t j = 0; j < featureValues[i].size(); ++j) {
|
||||
featureValues[i][j].ApplyLog(baseOfLog);
|
||||
}
|
||||
if (logFeatureValues) {
|
||||
for (size_t i = 0; i < featureValues.size(); ++i) {
|
||||
for (size_t j = 0; j < featureValues[i].size(); ++j) {
|
||||
featureValues[i][j].ApplyLog(baseOfLog);
|
||||
}
|
||||
|
||||
oracleFeatureValues[i].ApplyLog(baseOfLog);
|
||||
}
|
||||
}
|
||||
oracleFeatureValues[i].ApplyLog(baseOfLog);
|
||||
}
|
||||
}
|
||||
|
||||
// run optimiser on batch
|
||||
cerr << "\nRank " << rank << ", run optimiser.." << endl;
|
||||
ScoreComponentCollection oldWeights(mosesWeights);
|
||||
int constraintChange = optimiser->updateWeights(mosesWeights, featureValues, losses, bleuScores, oracleFeatureValues, oracleBleuScores, ref_ids);
|
||||
cerr << "\nRank " << rank << ", run optimiser.." << endl;
|
||||
ScoreComponentCollection oldWeights(mosesWeights);
|
||||
int constraintChange = optimiser->updateWeights(mosesWeights, featureValues, losses, bleuScores, oracleFeatureValues, oracleBleuScores, ref_ids);
|
||||
|
||||
// normalise Moses weights
|
||||
mosesWeights.L1Normalise();
|
||||
mosesWeights.L1Normalise();
|
||||
|
||||
// compute difference to old weights
|
||||
// print weights and features values
|
||||
cerr << "\nRank " << rank << ", weights (normalised): " << mosesWeights << endl;
|
||||
/* cerr << "Rank " << rank << ", feature values: " << endl;
|
||||
for (size_t i = 0; i < featureValues.size(); ++i) {
|
||||
for (size_t j = 0; j < featureValues[i].size(); ++j) {
|
||||
cerr << featureValues[i][j].Size() << ": " << featureValues[i][j] << endl;
|
||||
}
|
||||
}
|
||||
cerr << endl;*/
|
||||
|
||||
// compute difference to old weights
|
||||
ScoreComponentCollection weightDifference(mosesWeights);
|
||||
weightDifference.MinusEquals(oldWeights);
|
||||
//cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
|
||||
cerr << "Rank " << rank << ", weight difference: " << weightDifference << endl;
|
||||
|
||||
// update history (for approximate document Bleu)
|
||||
for (size_t i = 0; i < oracles.size(); ++i) {
|
||||
@ -432,82 +467,42 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
decoder->updateHistory(oracles, inputLengths, ref_ids);
|
||||
|
||||
// clean up oracle translations after updating history
|
||||
for (size_t i = 0; i < oracles.size(); ++i) {
|
||||
for (size_t j = 0; j < oracles[i].size(); ++j) {
|
||||
delete oracles[i][j];
|
||||
}
|
||||
if (!devBleu) {
|
||||
// clean up oracle translations after updating history
|
||||
for (size_t i = 0; i < oracles.size(); ++i) {
|
||||
for (size_t j = 0; j < oracles[i].size(); ++j) {
|
||||
delete oracles[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sanity check: compare margin created by old weights against new weights
|
||||
float lossMinusMargin_old = 0;
|
||||
float lossMinusMargin_new = 0;
|
||||
for (size_t batchPosition = 0; batchPosition < actualBatchSize; ++batchPosition) {
|
||||
for (size_t j = 0; j < featureValues[batchPosition].size(); ++j) {
|
||||
ScoreComponentCollection featureDiff(oracleFeatureValues[batchPosition]);
|
||||
featureDiff.MinusEquals(featureValues[batchPosition][j]);
|
||||
// set and accumulate weights
|
||||
decoder->setWeights(mosesWeights);
|
||||
cumulativeWeights.PlusEquals(mosesWeights);
|
||||
|
||||
// old weights
|
||||
float margin = featureDiff.InnerProduct(oldWeights);
|
||||
lossMinusMargin_old += (losses[batchPosition][j] - margin);
|
||||
// new weights
|
||||
margin = featureDiff.InnerProduct(mosesWeights);
|
||||
lossMinusMargin_new += (losses[batchPosition][j] - margin);
|
||||
|
||||
// now collect translations of first epoch only
|
||||
if (rank == 0 && epoch == 0) {
|
||||
list_of_delta_h.push_back(featureDiff);
|
||||
list_of_losses.push_back(losses[batchPosition][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cerr << "\nRank " << rank << ", constraint change: " << constraintChange << endl;
|
||||
cerr << "Rank " << rank << ", summed (loss - margin) with old weights: " << lossMinusMargin_old << endl;
|
||||
cerr << "Rank " << rank << ", summed (loss - margin) with new weights: " << lossMinusMargin_new << endl;
|
||||
bool useNewWeights = true;
|
||||
if (lossMinusMargin_new > lossMinusMargin_old) {
|
||||
cerr << "Rank " << rank << ", worsening: " << lossMinusMargin_new - lossMinusMargin_old << endl;
|
||||
|
||||
if (constraintChange < 0) {
|
||||
cerr << "Rank " << rank << ", something is going wrong here.." << endl;
|
||||
if (ignoreWeirdUpdates) {
|
||||
useNewWeights = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (useNewWeights) {
|
||||
decoder->setWeights(mosesWeights);
|
||||
cumulativeWeights.PlusEquals(mosesWeights);
|
||||
}
|
||||
else {
|
||||
cerr << "Ignore new weights, keep old weights.. " << endl;
|
||||
}
|
||||
|
||||
|
||||
++iterations;
|
||||
++iterationsThisEpoch;
|
||||
++weightChanges;
|
||||
++weightChangesThisEpoch;
|
||||
|
||||
// mix weights?
|
||||
#ifdef MPI_ENABLE
|
||||
if (shardPosition % (shard.size() / mixFrequency) == 0) {
|
||||
ScoreComponentCollection averageWeights;
|
||||
//if (rank == 0) {
|
||||
// cerr << "Rank 0, before mixing: " << mosesWeights << endl;
|
||||
//}
|
||||
if (rank == 0) {
|
||||
cerr << "Rank 0, before mixing: " << mosesWeights << endl;
|
||||
}
|
||||
|
||||
//VERBOSE(1, "\nRank: " << rank << " \nBefore mixing: " << mosesWeights << endl);
|
||||
VERBOSE(1, "\nRank: " << rank << " \nBefore mixing: " << mosesWeights << endl);
|
||||
|
||||
// collect all weights in averageWeights and divide by number of processes
|
||||
mpi::reduce(world, mosesWeights, averageWeights, SCCPlus(), 0);
|
||||
if (rank == 0) {
|
||||
averageWeights.DivideEquals(size);
|
||||
//VERBOSE(1, "After mixing: " << averageWeights << endl);
|
||||
//cerr << "Rank 0, after mixing: " << averageWeights << endl;
|
||||
|
||||
// normalise weights after averaging
|
||||
averageWeights.L1Normalise();
|
||||
|
||||
VERBOSE(1, "After mixing (normalised): " << averageWeights << endl);
|
||||
cerr << "Rank 0, after mixing (normalised): " << averageWeights << endl;
|
||||
}
|
||||
|
||||
// broadcast average weights from process 0
|
||||
@ -520,14 +515,14 @@ int main(int argc, char** argv) {
|
||||
// compute average weights per process over iterations
|
||||
ScoreComponentCollection totalWeights(cumulativeWeights);
|
||||
if (accumulateWeights)
|
||||
totalWeights.DivideEquals(iterations);
|
||||
totalWeights.DivideEquals(weightChanges);
|
||||
else
|
||||
totalWeights.DivideEquals(iterationsThisEpoch);
|
||||
totalWeights.DivideEquals(weightChangesThisEpoch);
|
||||
|
||||
//if (rank == 0) {
|
||||
//cerr << "Rank 0, cumulative weights: " << cumulativeWeights << endl;
|
||||
//cerr << "Rank 0, total weights: " << totalWeights << endl;
|
||||
//}
|
||||
if (rank == 0) {
|
||||
cerr << "Rank 0, cumulative weights: " << cumulativeWeights << endl;
|
||||
cerr << "Rank 0, total weights: " << totalWeights << endl;
|
||||
}
|
||||
|
||||
if (weightEpochDump + 1 == weightDumpFrequency){
|
||||
// last weight dump in epoch
|
||||
@ -545,10 +540,10 @@ int main(int argc, char** argv) {
|
||||
#endif
|
||||
|
||||
if (rank == 0 && !weightDumpStem.empty()) {
|
||||
// average and normalise weights
|
||||
// average by number of processes and normalise weights
|
||||
averageTotalWeights.DivideEquals(size);
|
||||
averageTotalWeights.L1Normalise();
|
||||
//cerr << "Rank 0, average total weights: " << averageTotalWeights << endl;
|
||||
cerr << "Rank 0, average total weights (normalised): " << averageTotalWeights << endl;
|
||||
|
||||
ostringstream filename;
|
||||
if (epoch < 10) {
|
||||
@ -567,20 +562,14 @@ int main(int argc, char** argv) {
|
||||
|
||||
if (weightEpochDump + 1 == weightDumpFrequency){
|
||||
// last weight dump in epoch
|
||||
|
||||
// compute summed error
|
||||
float summedError = 0;
|
||||
for (size_t i = 0; i < list_of_delta_h.size(); ++i) {
|
||||
summedError += (list_of_losses[i] - list_of_delta_h[i].InnerProduct(averageTotalWeights));
|
||||
}
|
||||
cerr << "Rank 0, summed error after dumping weights: " << summedError << " (" << list_of_delta_h.size() << " examples)" << endl;
|
||||
|
||||
// compare new average weights with previous weights
|
||||
averageTotalWeightsCurrent = averageTotalWeights;
|
||||
ScoreComponentCollection firstDiff(averageTotalWeightsCurrent);
|
||||
firstDiff.MinusEquals(averageTotalWeightsPrevious);
|
||||
cerr << "Rank 0, weight changes since previous epoch: " << firstDiff << endl;
|
||||
ScoreComponentCollection secondDiff(averageTotalWeightsCurrent);
|
||||
secondDiff.MinusEquals(averageTotalWeightsBeforePrevious);
|
||||
cerr << "Rank 0, weight changes since before previous epoch: " << secondDiff << endl;
|
||||
|
||||
if (!suppressConvergence) {
|
||||
// check whether stopping criterion has been reached
|
||||
@ -591,7 +580,7 @@ int main(int argc, char** argv) {
|
||||
FVector::const_iterator iterator1 = changes1.cbegin();
|
||||
FVector::const_iterator iterator2 = changes2.cbegin();
|
||||
while (iterator1 != changes1.cend()) {
|
||||
if ((*iterator1).second >= 0.01 || (*iterator2).second >= 0.01) {
|
||||
if (abs((*iterator1).second) >= min_weight_change || abs((*iterator2).second) >= min_weight_change) {
|
||||
reached = false;
|
||||
break;
|
||||
}
|
||||
@ -602,9 +591,8 @@ int main(int argc, char** argv) {
|
||||
|
||||
if (reached) {
|
||||
// stop MIRA
|
||||
cerr << "\nRank 0, stopping criterion has been reached after epoch " << epoch << ".. stopping MIRA." << endl << endl;
|
||||
//cerr << "Rank 0, average total weights: " << averageTotalWeights << endl;
|
||||
|
||||
cerr << "\nRank 0, stopping criterion has been reached after epoch " << epoch << ".. stopping MIRA." << endl;
|
||||
|
||||
ScoreComponentCollection dummy;
|
||||
ostringstream endfilename;
|
||||
endfilename << "stopping";
|
||||
@ -613,6 +601,17 @@ int main(int argc, char** argv) {
|
||||
#ifdef MPI_ENABLE
|
||||
MPI_Abort(MPI_COMM_WORLD, 0);
|
||||
#endif
|
||||
if (devBleu) {
|
||||
// calculate bleu score of all oracle translations of dev set
|
||||
decoder->calculateBleuOfCorpus(allOracles, all_ref_ids, epoch);
|
||||
}
|
||||
if (marginScaleFactorStep > 0) {
|
||||
cerr << "margin scale factor: " << marginScaleFactor << endl;
|
||||
}
|
||||
if (slack_step > 0) {
|
||||
cerr << "slack: " << slack << endl;
|
||||
}
|
||||
|
||||
goto end;
|
||||
}
|
||||
}
|
||||
@ -623,8 +622,47 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
//list_of_delta_h.clear();
|
||||
//list_of_losses.clear();
|
||||
if (devBleu) {
|
||||
// calculate bleu score of all oracle translations of dev set
|
||||
decoder->calculateBleuOfCorpus(allOracles, all_ref_ids, epoch);
|
||||
|
||||
// clean up oracle translations
|
||||
for (size_t i = 0; i < allOracles.size(); ++i) {
|
||||
for (size_t j = 0; j < allOracles[i].size(); ++j) {
|
||||
delete allOracles[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if using flexible margin scale factor, increase scaling (decrease value) for next epoch
|
||||
if (marginScaleFactorStep > 0) {
|
||||
if (marginScaleFactor - marginScaleFactorStep >= marginScaleFactorMin) {
|
||||
if (typeid(*optimiser) == typeid(MiraOptimiser)) {
|
||||
cerr << "old margin scale factor: " << marginScaleFactor << endl;
|
||||
marginScaleFactor -= marginScaleFactorStep;
|
||||
cerr << "new margin scale factor: " << marginScaleFactor << endl;
|
||||
((MiraOptimiser*)optimiser)->setMarginScaleFactor(marginScaleFactor);
|
||||
}
|
||||
}
|
||||
else {
|
||||
cerr << "margin scale factor: " << marginScaleFactor << endl;
|
||||
}
|
||||
}
|
||||
|
||||
// if using flexible slack, increase slack for next epoch
|
||||
if (slack_step > 0) {
|
||||
if (slack + slack_step <= slack_max) {
|
||||
if (typeid(*optimiser) == typeid(MiraOptimiser)) {
|
||||
cerr << "old slack: " << slack << endl;
|
||||
slack += slack_step;
|
||||
cerr << "new slack: " << slack << endl;
|
||||
((MiraOptimiser*)optimiser)->setSlack(slack);
|
||||
}
|
||||
}
|
||||
else {
|
||||
cerr << "slack: " << slack << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
end:
|
||||
@ -633,13 +671,9 @@ int main(int argc, char** argv) {
|
||||
MPI_Finalize();
|
||||
#endif
|
||||
|
||||
//if (rank == 0) {
|
||||
// cerr << "Rank 0, average total weights: " << averageTotalWeights << endl;
|
||||
//}
|
||||
|
||||
now = time(0); // get current time
|
||||
tm = localtime(&now); // get struct filled out
|
||||
cerr << "End date/time: " << tm->tm_mon+1 << "/" << tm->tm_mday << "/" << tm->tm_year + 1900
|
||||
cerr << "\nEnd date/time: " << tm->tm_mon+1 << "/" << tm->tm_mday << "/" << tm->tm_year + 1900
|
||||
<< ", " << tm->tm_hour << ":" << tm->tm_min << ":" << tm->tm_sec << endl;
|
||||
|
||||
delete decoder;
|
||||
|
@ -63,7 +63,8 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
|
||||
for (size_t i = 0; i < featureValues.size(); ++i) {
|
||||
size_t sentenceId = sentenceIds[i];
|
||||
cerr << "Available oracles for source sentence " << sentenceId << ": " << m_oracles[sentenceId].size() << endl;
|
||||
if (m_oracles[sentenceId].size() > 1)
|
||||
cerr << "Available oracles for source sentence " << sentenceId << ": " << m_oracles[sentenceId].size() << endl;
|
||||
for (size_t j = 0; j < featureValues[i].size(); ++j) {
|
||||
// check if optimisation criterion is violated for one hypothesis and the oracle
|
||||
// h(e*) >= h(e_ij) + loss(e_ij)
|
||||
@ -90,6 +91,14 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
addConstraint = false;
|
||||
}
|
||||
|
||||
/*if (modelScoreDiff < loss) {
|
||||
// constraint violated
|
||||
cerr << modelScoreDiff << " < " << loss << endl;
|
||||
}
|
||||
else {
|
||||
cerr << modelScoreDiff << " >= " << loss << endl;
|
||||
}*/
|
||||
|
||||
if (addConstraint) {
|
||||
float lossMarginDistance = loss - modelScoreDiff;
|
||||
|
||||
@ -134,7 +143,7 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
m_featureValueDiffs.push_back(maxViolationfeatureValueDiff);
|
||||
m_lossMarginDistances.push_back(maxViolationLossMarginDistance);
|
||||
|
||||
cerr << "Number of constraints passed to optimiser: " << m_featureValueDiffs.size() << endl;
|
||||
// cerr << "Number of constraints passed to optimiser: " << m_featureValueDiffs.size() << endl;
|
||||
if (m_slack != 0) {
|
||||
alphas = Hildreth::optimise(m_featureValueDiffs, m_lossMarginDistances, m_slack);
|
||||
}
|
||||
@ -176,10 +185,28 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
m_lossMarginDistances.push_back(maxViolationLossMarginDistance);
|
||||
}
|
||||
|
||||
//cerr << "Number of violated constraints before optimisation: " << violatedConstraintsBefore << endl;
|
||||
cerr << "Number of constraints passed to optimiser: " << featureValueDiffs.size() << endl;
|
||||
cerr << "Number of violated constraints before optimisation: " << violatedConstraintsBefore << endl;
|
||||
if (featureValueDiffs.size() != 30) {
|
||||
cerr << "Number of constraints passed to optimiser: " << featureValueDiffs.size() << endl;
|
||||
}
|
||||
if (m_slack != 0) {
|
||||
/*cerr << "Feature value diffs (A): " << endl;
|
||||
for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
|
||||
cerr << featureValueDiffs[k] << endl;
|
||||
}
|
||||
cerr << endl << "Loss - margin (b): " << endl;
|
||||
for (size_t i = 0; i < lossMarginDistances.size(); ++i) {
|
||||
cerr << lossMarginDistances[i] << endl;
|
||||
}
|
||||
cerr << endl;
|
||||
cerr << "Slack: " << m_slack << endl;*/
|
||||
|
||||
alphas = Hildreth::optimise(featureValueDiffs, lossMarginDistances, m_slack);
|
||||
/*cerr << "Alphas: " << endl;
|
||||
for (size_t i = 0; i < alphas.size(); ++i) {
|
||||
cerr << alphas[i] << endl;
|
||||
}
|
||||
cerr << endl << endl;*/
|
||||
}
|
||||
else {
|
||||
alphas = Hildreth::optimise(featureValueDiffs, lossMarginDistances);
|
||||
@ -189,23 +216,29 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
// * w' = w' + delta * Dh_ij ---> w' = w' + delta * (h(e*) - h(e_ij))
|
||||
for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
|
||||
// compute update
|
||||
float update = alphas[k];
|
||||
float alpha = alphas[k];
|
||||
if (m_fixedClipping) {
|
||||
if (update > m_c) {
|
||||
update = m_c;
|
||||
if (alpha > m_c) {
|
||||
alpha = m_c;
|
||||
}
|
||||
else if (update < -1 * m_c) {
|
||||
update = -1 * m_c;
|
||||
else if (alpha < -1 * m_c) {
|
||||
alpha = -1 * m_c;
|
||||
}
|
||||
}
|
||||
|
||||
featureValueDiffs[k].MultiplyEquals(update);
|
||||
cerr << "alpha: " << update << endl;
|
||||
featureValueDiffs[k].MultiplyEquals(alpha);
|
||||
//cerr << "alpha: " << alpha << endl;
|
||||
|
||||
// apply update to weight vector
|
||||
currWeights.PlusEquals(featureValueDiffs[k]);
|
||||
}
|
||||
|
||||
/*cerr << "Updates to weight vector: " << endl;
|
||||
for (size_t k = 0; k < featureValueDiffs.size(); ++k) {
|
||||
cerr << featureValueDiffs[k] << endl;
|
||||
}
|
||||
cerr << endl << endl;*/
|
||||
|
||||
// sanity check: how many constraints violated after optimisation?
|
||||
size_t violatedConstraintsAfter = 0;
|
||||
for (size_t i = 0; i < featureValues.size(); ++i) {
|
||||
@ -216,15 +249,16 @@ int MiraOptimiser::updateWeights(ScoreComponentCollection& currWeights,
|
||||
float loss = losses[i][j] * m_marginScaleFactor;
|
||||
if (modelScoreDiff < loss) {
|
||||
++violatedConstraintsAfter;
|
||||
//cerr << modelScoreDiff << " < " << loss << endl;
|
||||
}
|
||||
else {
|
||||
//cerr << modelScoreDiff << " >= " << loss << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//cerr << "Number of violated constraints after optimisation: " << violatedConstraintsAfter << endl;
|
||||
if (violatedConstraintsAfter > violatedConstraintsBefore) {
|
||||
cerr << "Increase: " << violatedConstraintsAfter - violatedConstraintsBefore << endl << endl;
|
||||
}
|
||||
|
||||
int constraintChange = violatedConstraintsBefore - violatedConstraintsAfter;
|
||||
cerr << "Constraint change: " << constraintChange << endl;
|
||||
return violatedConstraintsBefore - violatedConstraintsAfter;
|
||||
}
|
||||
else {
|
||||
|
@ -104,6 +104,14 @@ namespace Mira {
|
||||
void setOracleIndices(std::vector<size_t> oracleIndices) {
|
||||
m_oracleIndices= oracleIndices;
|
||||
}
|
||||
|
||||
void setSlack(float slack) {
|
||||
m_slack = slack;
|
||||
}
|
||||
|
||||
void setMarginScaleFactor(float msf) {
|
||||
m_marginScaleFactor = msf;
|
||||
}
|
||||
|
||||
private:
|
||||
// number of hypotheses used for each nbest list (number of hope, fear, best model translations)
|
||||
|
@ -238,6 +238,51 @@ void BleuScoreFeature::GetNgramMatchCounts(Phrase& phrase,
|
||||
}
|
||||
}
|
||||
|
||||
void BleuScoreFeature::GetClippedNgramMatchesAndCounts(Phrase& phrase,
|
||||
const NGrams& ref_ngram_counts,
|
||||
std::vector< size_t >& ret_counts,
|
||||
std::vector< size_t >& ret_matches,
|
||||
size_t skip_first) const
|
||||
{
|
||||
std::map< Phrase, size_t >::const_iterator ref_ngram_counts_iter;
|
||||
size_t ngram_start_idx, ngram_end_idx;
|
||||
|
||||
std::map<size_t, std::map<Phrase, size_t> > ngram_matches;
|
||||
for (size_t end_idx = skip_first; end_idx < phrase.GetSize(); end_idx++) {
|
||||
for (size_t order = 0; order < BleuScoreState::bleu_order; order++) {
|
||||
if (order > end_idx) break;
|
||||
|
||||
ngram_end_idx = end_idx;
|
||||
ngram_start_idx = end_idx - order;
|
||||
|
||||
Phrase ngram = phrase.GetSubString(WordsRange(ngram_start_idx, ngram_end_idx));
|
||||
string ngramString = ngram.ToString();
|
||||
ret_counts[order]++;
|
||||
|
||||
ref_ngram_counts_iter = ref_ngram_counts.find(ngram);
|
||||
if (ref_ngram_counts_iter != ref_ngram_counts.end()) {
|
||||
ngram_matches[order][ngram]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clip ngram matches
|
||||
for (size_t order = 0; order < BleuScoreState::bleu_order; order++) {
|
||||
std::map<Phrase, size_t>::const_iterator iter;
|
||||
|
||||
// iterate over ngram counts for every ngram order
|
||||
for (iter=ngram_matches[order].begin(); iter != ngram_matches[order].end(); ++iter) {
|
||||
ref_ngram_counts_iter = ref_ngram_counts.find(iter->first);
|
||||
if (iter->second > ref_ngram_counts_iter->second) {
|
||||
ret_matches[order] += ref_ngram_counts_iter->second;
|
||||
}
|
||||
else {
|
||||
ret_matches[order] += iter->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Given a previous state, compute Bleu score for the updated state with an additional target
|
||||
* phrase translated.
|
||||
@ -371,6 +416,94 @@ float BleuScoreFeature::CalculateBleu(BleuScoreState* state) const {
|
||||
return precision;
|
||||
}
|
||||
|
||||
vector<float> BleuScoreFeature::CalculateBleuOfCorpus(const vector< vector< const Word* > >& oracles, const vector<size_t>& ref_ids) {
|
||||
// get ngram matches and counts for all oracle sentences and their references
|
||||
vector<size_t> sumOfClippedNgramMatches(BleuScoreState::bleu_order);
|
||||
vector<size_t> sumOfNgramCounts(BleuScoreState::bleu_order);
|
||||
size_t ref_length = 0;
|
||||
size_t target_length = 0;
|
||||
|
||||
for (size_t batchPosition = 0; batchPosition < oracles.size(); ++batchPosition){
|
||||
Phrase phrase(Output, oracles[batchPosition]);
|
||||
size_t ref_id = ref_ids[batchPosition];
|
||||
size_t cur_ref_length = m_refs[ref_id].first;
|
||||
NGrams cur_ref_ngrams = m_refs[ref_id].second;
|
||||
|
||||
ref_length += cur_ref_length;
|
||||
target_length += oracles[batchPosition].size();
|
||||
|
||||
std::vector< size_t > ngram_counts(BleuScoreState::bleu_order);
|
||||
std::vector< size_t > clipped_ngram_matches(BleuScoreState::bleu_order);
|
||||
GetClippedNgramMatchesAndCounts(phrase, cur_ref_ngrams, ngram_counts, clipped_ngram_matches, 0);
|
||||
|
||||
// add clipped ngram matches and ngram counts to corpus sums
|
||||
for (size_t i = 0; i < BleuScoreState::bleu_order; i++) {
|
||||
sumOfClippedNgramMatches[i] += clipped_ngram_matches[i];
|
||||
sumOfNgramCounts[i] += ngram_counts[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (!sumOfNgramCounts[0]) {
|
||||
vector<float> empty(0);
|
||||
return empty;
|
||||
}
|
||||
if (!sumOfClippedNgramMatches[0]) {
|
||||
vector<float> empty(0);
|
||||
return empty; // if we have no unigram matches, score should be 0
|
||||
}
|
||||
|
||||
// calculate bleu score
|
||||
float precision = 1.0;
|
||||
float smoothed_count, smoothed_matches;
|
||||
|
||||
vector<float> bleu;
|
||||
// Calculate geometric mean of modified ngram precisions
|
||||
// BLEU = BP * exp(SUM_1_4 1/4 * log p_n)
|
||||
// = BP * 4th root(PRODUCT_1_4 p_n)
|
||||
for (size_t i = 0; i < BleuScoreState::bleu_order; i++) {
|
||||
if (sumOfNgramCounts[i]) {
|
||||
smoothed_matches = sumOfClippedNgramMatches[i];
|
||||
smoothed_count = sumOfNgramCounts[i];
|
||||
if (i > 0) {
|
||||
// smoothing for all n > 1
|
||||
smoothed_matches += 1;
|
||||
smoothed_count += 1;
|
||||
}
|
||||
|
||||
precision *= smoothed_matches / smoothed_count;
|
||||
bleu.push_back(smoothed_matches / smoothed_count);
|
||||
}
|
||||
else {
|
||||
cerr << "no counts for order " << i+1 << endl;
|
||||
}
|
||||
}
|
||||
|
||||
// take geometric mean
|
||||
precision = pow(precision, (float)1/4);
|
||||
|
||||
// Apply brevity penalty if applicable.
|
||||
// BP = 1 if c > r
|
||||
// BP = e^(1- r/c)) if c <= r
|
||||
// where
|
||||
// c: length of the candidate translation
|
||||
// r: effective reference length (sum of best match lengths for each candidate sentence)
|
||||
float BP;
|
||||
if (target_length < ref_length) {
|
||||
precision *= exp(1 - (1.0*ref_length/target_length));
|
||||
BP = exp(1 - (1.0*ref_length/target_length));
|
||||
}
|
||||
else {
|
||||
BP = 1.0;
|
||||
}
|
||||
|
||||
bleu.push_back(precision);
|
||||
bleu.push_back(BP);
|
||||
bleu.push_back(1.0*target_length/ref_length);
|
||||
bleu.push_back(target_length);
|
||||
bleu.push_back(ref_length);
|
||||
return bleu;
|
||||
}
|
||||
|
||||
const FFState* BleuScoreFeature::EmptyHypothesisState(const InputType& input) const
|
||||
{
|
||||
return new BleuScoreState();
|
||||
|
@ -72,11 +72,17 @@ public:
|
||||
std::vector< size_t >&,
|
||||
std::vector< size_t >&,
|
||||
size_t skip = 0) const;
|
||||
void GetClippedNgramMatchesAndCounts(Phrase&,
|
||||
const NGrams&,
|
||||
std::vector< size_t >&,
|
||||
std::vector< size_t >&,
|
||||
size_t skip = 0) const;
|
||||
|
||||
FFState* Evaluate( const Hypothesis& cur_hypo,
|
||||
const FFState* prev_state,
|
||||
ScoreComponentCollection* accumulator) const;
|
||||
float CalculateBleu(BleuScoreState*) const;
|
||||
std::vector<float> CalculateBleuOfCorpus(const std::vector< std::vector< const Word* > >& hypos, const std::vector<size_t>& ref_ids);
|
||||
const FFState* EmptyHypothesisState(const InputType&) const;
|
||||
|
||||
private:
|
||||
|
Loading…
Reference in New Issue
Block a user