Added pairwise ranked optimization (PRO) as proposed by [Hopkins&May,2011]; to enable it, use the switch --pairwise-ranked

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4106 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
phkoehn 2011-08-03 17:00:17 +00:00
parent 579d8b0760
commit 36db0ffe48
4 changed files with 170 additions and 16 deletions

View File

@ -103,3 +103,109 @@ void Data::loadnbest(const std::string &file)
inp.close();
}
// really not the right place...
/**
 * Sentence-level BLEU score with add-one smoothing.
 *
 * For each of the four n-gram orders, the statistics at an even index and
 * the following odd index are read as numerator and denominator of a
 * precision; one is added to both before taking the log ratio, and the four
 * log ratios are averaged. A brevity term 1 - stats[8]/stats[1] is added to
 * the log score only when it is negative.
 * (presumably stats[8] is the reference length and stats[1] the hypothesis
 * length -- confirm against the scorer that fills ScoreStats.)
 *
 * @param stats per-translation score statistics; entries 0..8 are read
 * @return exp of the smoothed, length-penalized average log precision
 */
float sentenceLevelBleuPlusOne( ScoreStats &stats ) {
  const unsigned int maxOrder = 4;
  float logScore = 0.0;
  // walk the (numerator, denominator) pairs two entries at a time
  for (unsigned int idx = 0; idx < 2*maxOrder; idx += 2) {
    logScore += log(stats.get(idx)+1) - log(stats.get(idx+1)+1);
  }
  logScore /= maxOrder;

  const float penalty = 1.0 - (float)stats.get(2*maxOrder)/stats.get(1);
  if (penalty < 0.0) {
    logScore += penalty;
  }
  return exp(logScore);
}
/**
 * A pair of translations of the same sentence together with the absolute
 * difference of their scores.
 *
 * The pair is normalized on construction: if the signed difference passed in
 * is positive the indices are kept in order, otherwise (including a
 * difference of exactly zero) they are swapped and the difference negated.
 * As a result getTranslation1() always names the translation that scored at
 * least as well, and getDiff() is never negative for finite input.
 */
class SampledPair {
private:
  unsigned int translation1;
  unsigned int translation2;
  float scoreDiff;
public:
  /**
   * @param t1   index of the first sampled translation
   * @param t2   index of the second sampled translation
   * @param diff score of t1 minus score of t2
   */
  SampledPair( unsigned int t1, unsigned int t2, float diff )
    : translation1(diff > 0 ? t1 : t2),
      translation2(diff > 0 ? t2 : t1),
      scoreDiff(diff > 0 ? diff : -diff) {}

  /// Absolute score difference between the two translations.
  float getDiff() const { return scoreDiff; }
  /// Index of the translation with the higher (or equal) score.
  unsigned int getTranslation1() const { return translation1; }
  /// Index of the translation with the lower (or equal) score.
  unsigned int getTranslation2() const { return translation2; }
};
void Data::sample_ranked_pairs( const std::string &rankedpairfile ) {
cout << "Sampling ranked pairs." << endl;
ofstream *outFile = new ofstream();
outFile->open( rankedpairfile.c_str() );
ostream *out = outFile;
const unsigned int n_samplings = 5000;
const unsigned int n_samples = 50;
const float min_diff = 0.05;
// loop over all sentences
for(unsigned int S=0; S<featdata->size(); S++) {
unsigned int n_translations = featdata->get(S).size();
// sample a fixed number of times
vector< SampledPair* > samples;
vector< float > scores;
for(unsigned int i=0; i<n_samplings; i++) {
unsigned int translation1 = rand() % n_translations;
float bleu1 = sentenceLevelBleuPlusOne(scoredata->get(S,translation1));
unsigned int translation2 = rand() % n_translations;
float bleu2 = sentenceLevelBleuPlusOne(scoredata->get(S,translation2));
if (abs(bleu1-bleu2) < min_diff)
continue;
samples.push_back( new SampledPair( translation1, translation2, bleu1-bleu2) );
scores.push_back( 1.0 - abs(bleu1-bleu2) );
}
//cerr << "sampled " << samples.size() << " pairs\n";
float min_diff = -1.0;
if (samples.size() > n_samples) {
nth_element(scores.begin(), scores.begin()+(n_samples-1), scores.end());
min_diff = 0.99999-scores[n_samples-1];
//cerr << "min_diff = " << min_diff << endl;
}
unsigned int collected = 0;
for(unsigned int i=0; i<samples.size() && collected < n_samples; i++) {
if (samples[i]->getDiff() >= min_diff) {
collected++;
FeatureStats &f1 = featdata->get(S,samples[i]->getTranslation1());
FeatureStats &f2 = featdata->get(S,samples[i]->getTranslation2());
*out << "1";
for(unsigned int j=0; j<f1.size(); j++)
if (abs(f1.get(j)-f2.get(j)) > 0.00001)
*out << " F" << j << " " << (f1.get(j)-f2.get(j));
*out << endl;
*out << "0";
for(unsigned int j=0; j<f1.size(); j++)
if (abs(f1.get(j)-f2.get(j)) > 0.00001)
*out << " F" << j << " " << (f2.get(j)-f1.get(j));
*out << endl;
}
delete samples[i];
}
//cerr << "collected " << collected << endl;
}
out->flush();
outFile->close();
}

View File

@ -85,9 +85,12 @@ public:
/// Returns the name of the feature at position idx, as reported by featdata.
inline std::string getFeatureName(size_t idx) {
  std::string featureName = featdata->getFeatureName(idx);
  return featureName;
}
/// Returns the position of the feature called name, as reported by featdata.
inline size_t getFeatureIndex(const std::string& name) {
  size_t featureIndex = featdata->getFeatureIndex(name);
  return featureIndex;
}
// Samples pairs of translations for every sentence, keeps pairs whose
// sentence-level BLEU scores differ sufficiently, and writes their feature
// differences as "1"/"0" labelled lines to rankedPairFile.
void sample_ranked_pairs( const std::string &rankedPairFile );
};

View File

@ -31,11 +31,12 @@ using namespace std;
void usage(void)
{
cerr<<"usage: mert -d <dimensions> (mandatory )"<<endl;
cerr<<"[-n retry ntimes (default 1)]"<<endl;
cerr<<"[-m number of random directions in powell (default 0)]"<<endl;
cerr<<"[-o\tthe indexes to optimize(default all)]"<<endl;
cerr<<"[-t\tthe optimizer(default powell)]"<<endl;
cerr<<"[-r\tthe random seed (defaults to system clock)"<<endl;
cerr<<"[-n] retry ntimes (default 1)"<<endl;
cerr<<"[-m] number of random directions in powell (default 0)"<<endl;
cerr<<"[-o] the indexes to optimize(default all)"<<endl;
cerr<<"[-t] the optimizer(default powell)"<<endl;
cerr<<"[-r] the random seed (defaults to system clock)"<<endl;
cerr<<"[-p] only create data for paired ranked optimizer"<<endl;
cerr<<"[--sctype|-s] the scorer type (default BLEU)"<<endl;
cerr<<"[--scconfig|-c] configuration string passed to scorer"<<endl;
cerr<<"[--scfile|-S] comma separated list of scorer data files (default score.data)"<<endl;
@ -52,6 +53,7 @@ static struct option long_options[] = {
{"nrandom",1,0,'m'},
{"rseed",required_argument,0,'r'},
{"optimize",1,0,'o'},
{"pro",required_argument,0,'p'},
{"type",1,0,'t'},
{"sctype",1,0,'s'},
{"scconfig",required_argument,0,'c'},
@ -87,6 +89,7 @@ int main (int argc, char **argv)
string scorerfile("statscore.data");
string featurefile("features.data");
string initfile("init.opt");
string pairedrankfile("");
string tooptimizestr("");
vector<unsigned> tooptimize;
@ -95,11 +98,14 @@ int main (int argc, char **argv)
vector<parameter_t> max;
//note: those mins and max are the bound for the starting points of the algorithm, not strict bound on the result!
while ((c=getopt_long (argc, argv, "o:r:d:n:m:t:s:S:F:v:", long_options, &option_index)) != -1) {
while ((c=getopt_long (argc, argv, "o:r:d:n:m:t:s:S:F:v:p:", long_options, &option_index)) != -1) {
switch (c) {
case 'o':
tooptimizestr = string(optarg);
break;
case 'p':
pairedrankfile = string(optarg);
break;
case 'd':
pdim = strtol(optarg, NULL, 10);
break;
@ -250,6 +256,12 @@ int main (int argc, char **argv)
}
}
if (pairedrankfile.compare("") != 0) {
D.sample_ranked_pairs(pairedrankfile);
PrintUserTime("Stopping...");
exit(0);
}
Optimizer *O=OptimizerFactory::BuildOptimizer(pdim,tooptimize,start_list[0],type,nrandom);
O->SetScorer(TheScorer);
O->SetFData(D.getFeatureData());

View File

@ -118,6 +118,7 @@ my $___PREDICTABLE_SEEDS = 0;
my $___START_WITH_HISTORIC_BESTS = 0; # use best settings from all previous iterations as starting points [Foster&Kuhn,2009]
my $___RANDOM_DIRECTIONS = 0; # search in random directions only
my $___NUM_RANDOM_DIRECTIONS = 0; # number of random directions, also works with default optimizer [Cer&al.,2008]
my $___PAIRWISE_RANKED_OPTIMIZER = 0; # use Hopkins&May[2011]
# Parameter for effective reference length when computing BLEU score
# Default is to use shortest reference
@ -203,6 +204,7 @@ GetOptions(
"prev-aggregate-nbestlist=i" => \$prev_aggregate_nbl_size, #number of previous step to consider when loading data (default =-1, i.e. all previous)
"maximum-iterations=i" => \$maximum_iterations,
"starting-weights-from-ini!" => \$starting_weights_from_ini,
"pairwise-ranked" => \$___PAIRWISE_RANKED_OPTIMIZER
) or exit(1);
# the 4 required parameters can be supplied on the command line directly
@ -311,6 +313,15 @@ my $mert_mert_cmd = "$mertdir/mert";
die "Not executable: $mert_extract_cmd" if ! -x $mert_extract_cmd;
die "Not executable: $mert_mert_cmd" if ! -x $mert_mert_cmd;
# Classifier binary used for pairwise ranked optimization; expected in the
# same directory as the mert binaries.
my $pro_optimizer = "$mertdir/megam_i686.opt"; # or set to your installation
# Only when --pairwise-ranked was requested and the binary is not executable:
# download the gzipped binary into $mertdir, unpack it, mark it executable,
# and abort if it still cannot be executed afterwards.
# NOTE(review): this fetches an executable over plain http and runs it later
# without any integrity check, and needs wget/gunzip plus write access to
# $mertdir -- confirm this is acceptable for your installation.
if ($___PAIRWISE_RANKED_OPTIMIZER && ! -x $pro_optimizer) {
print "did not find $pro_optimizer, installing it in $mertdir\n";
`cd $mertdir; wget http://www.cs.utah.edu/~hal/megam/megam_i686.opt.gz;`;
`gunzip $pro_optimizer.gz`;
`chmod +x $pro_optimizer`;
die("ERROR: Installation of megam_i686.opt failed! Install by hand from http://www.cs.utah.edu/~hal/megam/") unless -x $pro_optimizer;
}
$mertargs = "" if !defined $mertargs;
my $scconfig = undef;
@ -754,9 +765,14 @@ while(1) {
$cmd = $cmd." --ifile run$run.$weights_in_file";
}
# Pairwise ranked optimization: mert is only asked to dump the sampled pairs
# to pro.data (--pro). A placeholder is written to the weights file so the
# "does not exist or is empty" check after the run still passes, and the
# classifier is then trained on pro.data; its output ends up in the mert
# output file, from which the weights are read later.
# The classifier is invoked through $pro_optimizer, the binary that was
# checked (and installed if necessary) at start-up, rather than through a
# path inside one particular user's home directory.
if ($___PAIRWISE_RANKED_OPTIMIZER) {
$cmd .= " --pro pro.data ; echo 'not used' > $weights_out_file; $pro_optimizer -fvals -maxi 30 -nobias binary pro.data";
}
if (defined $___JOBS && $___JOBS > 0) {
safesystem("$qsubwrapper $pass_old_sge -command='$cmd' -stdout=$mert_outfile -stderr=$mert_logfile -queue-parameter=\"$queue_flags\"") or die "Failed to start mert (via qsubwrapper $qsubwrapper)";
} else {
}
else {
safesystem("$cmd > $mert_outfile 2> $mert_logfile") or die "Failed to run mert";
}
die "Optimization failed, file $weights_out_file does not exist or is empty"
@ -766,6 +782,7 @@ while(1) {
# backup copies
safesystem ("\\cp -f extract.err run$run.extract.err") or die;
safesystem ("\\cp -f extract.out run$run.extract.out") or die;
if ($___PAIRWISE_RANKED_OPTIMIZER) { safesystem ("\\cp -f pro.data run$run.pro.data") or die; }
safesystem ("\\cp -f $mert_outfile run$run.$mert_outfile") or die;
safesystem ("\\cp -f $mert_logfile run$run.$mert_logfile") or die;
safesystem ("touch $mert_logfile run$run.$mert_logfile") or die;
@ -775,15 +792,32 @@ while(1) {
$bestpoint = undef;
$devbleu = undef;
open(IN,"run$run.$mert_logfile") or die "Can't open run$run.$mert_logfile";
while (<IN>) {
if (/Best point:\s*([\s\d\.\-e]+?)\s*=> ([\-\d\.]+)/) {
$bestpoint = $1;
$devbleu = $2;
last;
if ($___PAIRWISE_RANKED_OPTIMIZER) {
open(IN,"run$run.$mert_outfile") or die "Can't open run$run.$mert_outfile";
my (@WEIGHT,$sum);
foreach (@CURR) { push @WEIGHT, 0; }
while(<IN>) {
if (/^F(\d+) ([\-\.\de]+)/) {
$WEIGHT[$1] = $2;
$sum += abs($2);
}
}
$devbleu = "unknown";
foreach (@WEIGHT) { $_ /= $sum; }
$bestpoint = join(" ",@WEIGHT);
close IN;
}
else {
open(IN,"run$run.$mert_logfile") or die "Can't open run$run.$mert_logfile";
while (<IN>) {
if (/Best point:\s*([\s\d\.\-e]+?)\s*=> ([\-\d\.]+)/) {
$bestpoint = $1;
$devbleu = $2;
last;
}
}
close IN;
}
close IN;
die "Failed to parse mert.log, missed Best point there."
if !defined $bestpoint || !defined $devbleu;
print "($run) BEST at $run: $bestpoint => $devbleu at ".`date`;
@ -808,7 +842,6 @@ while(1) {
print F $run."\n";
close F;
if ($shouldstop) {
print STDERR "None of the weights changed more than $minimum_required_change_in_weights. Stopping.\n";
last;
@ -939,7 +972,7 @@ sub run_decoder {
print "decoder_config = $decoder_config\n";
# run the decoder
my $nBest_cmd = "-n-best-size $___N_BEST_LIST_SIZE";
my $nBest_cmd = "-n-best-size $___N_BEST_LIST_SIZE";
my $decoder_cmd;
if (defined $___JOBS && $___JOBS > 0) {