mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-09-20 07:42:21 +03:00
update training script
git-svn-id: http://svn.statmt.org/repository/mira@3885 cc96ff50-19ce-11e0-b349-13d7f0bd23df
This commit is contained in:
parent
479ba8d160
commit
1f6f8b4abb
@ -77,7 +77,6 @@ my $trainer_exe = ¶m_required("train.trainer");
|
||||
#my $weights_file = ¶m_required("train.weights-file");
|
||||
#&check_exists("weights file ", $weights_file);
|
||||
|
||||
|
||||
#optional training parameters
|
||||
my $epochs = ¶m("train.epochs", 2);
|
||||
my $learner = ¶m("train.learner", "mira");
|
||||
@ -87,9 +86,11 @@ my $continue_from_epoch = ¶m("train.continue-from-epoch", 0);
|
||||
my $by_node = ¶m("train.by-node",0);
|
||||
my $slots = ¶m("train.slots",8);
|
||||
my $jobs = ¶m("train.jobs",8);
|
||||
my $mixing_frequency = ¶m("train.mixing-frequency",1);
|
||||
my $weight_dump_frequency = ¶m("train.weight-dump-frequency",1);
|
||||
|
||||
my $mixing_frequency = ¶m("train.mixing-frequency",0);
|
||||
my $weight_dump_frequency = ¶m("train.weight-dump-frequency",0);
|
||||
my $burn_in = ¶m("train.burn-in",0);
|
||||
my $burn_in_input_file = ¶m("train.burn-in-input-file");
|
||||
my $burn_in_reference_files = ¶m("train.burn-in-reference-files");
|
||||
|
||||
#test configuration
|
||||
my ($test_input_file, $test_reference_file,$test_ini_file,$bleu_script,$use_moses);
|
||||
@ -112,33 +113,54 @@ my $skip_test = ¶m("test.skip-test",0);
|
||||
my $skip_dev = ¶m("test.skip-dev",0);
|
||||
|
||||
|
||||
# adjust test frequency when using batches > 1
|
||||
if ($batch > 1) {
|
||||
$mixing_frequency = 1;
|
||||
}
|
||||
|
||||
# check that number of jobs, dump frequency and number of input sentences are compatible
|
||||
# shard size = number of input sentences / number of jobs, ensure shard size >= dump frequency
|
||||
my $result = `wc -l $input_file`;
|
||||
my @result = split(/\s/, $result);
|
||||
my $inputSize = $result[0];
|
||||
my $shardSize = $inputSize / $jobs;
|
||||
if ($shardSize < $mixing_frequency) {
|
||||
$mixing_frequency = int($shardSize);
|
||||
if ($mixing_frequency == 0) {
|
||||
$mixing_frequency = 1;
|
||||
}
|
||||
if ($mixing_frequency != 0) {
|
||||
if ($shardSize < $mixing_frequency) {
|
||||
$mixing_frequency = int($shardSize);
|
||||
if ($mixing_frequency == 0) {
|
||||
$mixing_frequency = 1;
|
||||
}
|
||||
|
||||
print "Warning: mixing frequency must not be larger than shard size, setting mixing frequency to $mixing_frequency\n";
|
||||
print "Warning: mixing frequency must not be larger than shard size, setting mixing frequency to $mixing_frequency\n";
|
||||
}
|
||||
}
|
||||
|
||||
if ($shardSize < $weight_dump_frequency) {
|
||||
$weight_dump_frequency = int($shardSize);
|
||||
if ($weight_dump_frequency == 0) {
|
||||
$weight_dump_frequency = 1;
|
||||
if ($weight_dump_frequency != 0) {
|
||||
if ($shardSize < $weight_dump_frequency) {
|
||||
$weight_dump_frequency = int($shardSize);
|
||||
if ($weight_dump_frequency == 0) {
|
||||
$weight_dump_frequency = 1;
|
||||
}
|
||||
|
||||
print "Warning: weight dump frequency must not be larger than shard size, setting weight dump frequency to $weight_dump_frequency\n";
|
||||
}
|
||||
}
|
||||
|
||||
print "Warning: weight dump frequency must not be larger than shard size, setting weight dump frequency to $weight_dump_frequency\n";
|
||||
if ($mixing_frequency != 0) {
|
||||
if ($mixing_frequency > ($shardSize/$batch)) {
|
||||
$mixing_frequency = int($shardSize/$batch);
|
||||
if ($mixing_frequency == 0) {
|
||||
$mixing_frequency = 1;
|
||||
}
|
||||
|
||||
print "Warning: mixing frequency must not be larger than (shard size/batch size), setting mixing frequency to $mixing_frequency\n";
|
||||
}
|
||||
}
|
||||
|
||||
if ($weight_dump_frequency != 0) {
|
||||
if ($weight_dump_frequency > ($shardSize/$batch)) {
|
||||
$weight_dump_frequency = int($shardSize/$batch);
|
||||
if ($weight_dump_frequency == 0) {
|
||||
$weight_dump_frequency = 1;
|
||||
}
|
||||
|
||||
print "Warning: weight dump frequency must not be larger than (shard size/batch size), setting weight dump frequency to $weight_dump_frequency\n";
|
||||
}
|
||||
}
|
||||
|
||||
#file names
|
||||
@ -169,13 +191,30 @@ for my $ref (@refs) {
|
||||
print TRAIN "-r $ref ";
|
||||
}
|
||||
print TRAIN "\\\n";
|
||||
if ($burn_in) {
|
||||
print TRAIN "--burn-in 1 \\\n";
|
||||
print TRAIN "--burn-in-input-file $burn_in_input_file \\\n";
|
||||
my @refs;
|
||||
if (ref($burn_in_reference_files) eq 'ARRAY') {
|
||||
@refs = @$burn_in_reference_files;
|
||||
} else {
|
||||
@refs = glob $burn_in_reference_files;
|
||||
}
|
||||
for my $ref (@refs) {
|
||||
&check_exists("burn-in ref file", $ref);
|
||||
print TRAIN "--burn-in-reference-files $ref ";
|
||||
}
|
||||
print TRAIN "\\\n";
|
||||
}
|
||||
#if ($weights_file) {
|
||||
# print TRAIN "-w $weights_file \\\n";
|
||||
#}
|
||||
print TRAIN "-l $learner \\\n";
|
||||
print TRAIN "--weight-dump-stem $weight_file_stem \\\n";
|
||||
print TRAIN "--mixing-frequency $mixing_frequency \\\n";
|
||||
print TRAIN "--weight-dump-frequency $weight_dump_frequency \\\n";
|
||||
if ($weight_dump_frequency != -1) {
|
||||
print TRAIN "--weight-dump-frequency $weight_dump_frequency \\\n";
|
||||
}
|
||||
print TRAIN "--epochs $epochs \\\n";
|
||||
print TRAIN "-b $batch \\\n";
|
||||
print TRAIN "--decoder-settings \"$decoder_settings\" \\\n";
|
||||
@ -213,26 +252,26 @@ while(1) {
|
||||
my($epoch, $epoch_slice);
|
||||
$train_iteration += 1;
|
||||
my $new_weight_file = "$working_dir/$weight_file_stem" . "_";
|
||||
my $totalAverageWeightFile;
|
||||
if ($mixing_frequency == 1) {
|
||||
if ($weight_dump_frequency == 0) {
|
||||
print "No weights, no testing..\n";
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if ($weight_dump_frequency == 1) {
|
||||
if ($train_iteration < 10) {
|
||||
$new_weight_file .= "0".$train_iteration;
|
||||
$totalAverageWeightFile = $new_weight_file."_averageTotal";
|
||||
}
|
||||
else {
|
||||
$new_weight_file .= $train_iteration;
|
||||
$totalAverageWeightFile = $new_weight_file."_averageTotal";
|
||||
}
|
||||
} else {
|
||||
#my $epoch = 1 + int $train_iteration / $mixing_frequency;
|
||||
$epoch = int $train_iteration / $mixing_frequency;
|
||||
$epoch_slice = $train_iteration % $mixing_frequency;
|
||||
#my $epoch = 1 + int $train_iteration / $weight_dump_frequency;
|
||||
$epoch = int $train_iteration / $weight_dump_frequency;
|
||||
$epoch_slice = $train_iteration % $weight_dump_frequency;
|
||||
if ($epoch < 10) {
|
||||
$totalAverageWeightFile = $new_weight_file."0".$epoch."_averageTotal";
|
||||
$new_weight_file .= "0".$epoch."_".$epoch_slice;
|
||||
}
|
||||
else {
|
||||
$totalAverageWeightFile = $new_weight_file.$epoch."_averageTotal";
|
||||
$new_weight_file .= $epoch."_".$epoch_slice;
|
||||
}
|
||||
}
|
||||
@ -283,7 +322,7 @@ sub createTestScriptAndSubmit {
|
||||
my $output_file;
|
||||
my $output_error_file;
|
||||
my $bleu_file;
|
||||
if ($mixing_frequency == 1) {
|
||||
if ($weight_dump_frequency == 1) {
|
||||
if ($train_iteration < 10) {
|
||||
$output_file = $working_dir."/".$name."_0".$train_iteration.$suffix."_$testtype".".out";
|
||||
$output_error_file = $working_dir."/".$name."_0".$train_iteration.$suffix."_$testtype".".err";
|
||||
@ -335,6 +374,8 @@ sub createTestScriptAndSubmit {
|
||||
if (! (open WEIGHTS, "$core_weight_file")) {
|
||||
die "Unable to open weights file $core_weight_file\n";
|
||||
}
|
||||
|
||||
my $readCoreWeights = 0;
|
||||
my $readExtraWeights = 0;
|
||||
my %extra_weights;
|
||||
while(<WEIGHTS>) {
|
||||
@ -347,17 +388,23 @@ sub createTestScriptAndSubmit {
|
||||
} else {
|
||||
if ($name eq "WordPenalty") {
|
||||
$wordpenalty_weight = $value;
|
||||
$readCoreWeights += 1;
|
||||
} elsif ($name =~ /^PhraseModel/) {
|
||||
push @phrasemodel_weights,$value;
|
||||
$readCoreWeights += scalar @phrasemodel_weights;
|
||||
} elsif ($name =~ /^LM\:2/) {
|
||||
$lm2_weight = $value;
|
||||
$readCoreWeights += 1;
|
||||
}
|
||||
elsif ($name =~ /^LM/) {
|
||||
$lm_weight = $value;
|
||||
$readCoreWeights += 1;
|
||||
} elsif ($name eq "Distortion") {
|
||||
$distortion_weight = $value;
|
||||
$readCoreWeights += 1;
|
||||
} elsif ($name =~ /^LexicalReordering/) {
|
||||
push @lexicalreordering_weights,$value;
|
||||
$readCoreWeights += scalar @lexicalreordering_weights;
|
||||
} else {
|
||||
$extra_weights{$name} = $value;
|
||||
$readExtraWeights += 1;
|
||||
@ -368,8 +415,11 @@ sub createTestScriptAndSubmit {
|
||||
|
||||
print "Number of extra weights read: ".$readExtraWeights."\n";
|
||||
|
||||
die "LM weight not defined" unless defined $lm_weight;
|
||||
|
||||
if ($readCoreWeights == 0) {
|
||||
print "No core weights defined.. skipping weight file\n";
|
||||
return;
|
||||
}
|
||||
|
||||
# If there was a core weight file, then we have to load the weights
|
||||
# from the new weight file
|
||||
if ($core_weight_file ne $new_weight_file) {
|
||||
@ -470,6 +520,12 @@ sub createTestScriptAndSubmit {
|
||||
print TEST "#\$ -o $test_out\n";
|
||||
print TEST "#\$ -e $test_err\n";
|
||||
print TEST "\n";
|
||||
if ($have_sge) {
|
||||
# some eddie specific stuff
|
||||
print TEST ". /etc/profile.d/modules.sh\n";
|
||||
print TEST "module load openmpi/ethernet/gcc/latest\n";
|
||||
print TEST "export LD_LIBRARY_PATH=/exports/informatics/inf_iccs_smt/shared/boost/lib:\$LD_LIBRARY_PATH\n";
|
||||
}
|
||||
print TEST "$test_exe $decoder_settings -i $input_file -f $new_ini_file ";
|
||||
if ($extra_weight_file) {
|
||||
print TEST "-weight-file $extra_weight_file ";
|
||||
|
Loading…
Reference in New Issue
Block a user