From 1f6f8b4abb7658b424c3429c6a74df1000245a86 Mon Sep 17 00:00:00 2001 From: ehasler Date: Sat, 30 Apr 2011 16:17:37 +0000 Subject: [PATCH] update training script git-svn-id: http://svn.statmt.org/repository/mira@3885 cc96ff50-19ce-11e0-b349-13d7f0bd23df --- mira/training-expt.perl | 122 +++++++++++++++++++++++++++++----------- 1 file changed, 89 insertions(+), 33 deletions(-) diff --git a/mira/training-expt.perl b/mira/training-expt.perl index 55e15906e..425e0818f 100755 --- a/mira/training-expt.perl +++ b/mira/training-expt.perl @@ -77,7 +77,6 @@ my $trainer_exe = &param_required("train.trainer"); #my $weights_file = &param_required("train.weights-file"); #&check_exists("weights file ", $weights_file); - #optional training parameters my $epochs = &param("train.epochs", 2); my $learner = &param("train.learner", "mira"); @@ -87,9 +86,11 @@ my $continue_from_epoch = &param("train.continue-from-epoch", 0); my $by_node = &param("train.by-node",0); my $slots = &param("train.slots",8); my $jobs = &param("train.jobs",8); -my $mixing_frequency = &param("train.mixing-frequency",1); -my $weight_dump_frequency = &param("train.weight-dump-frequency",1); - +my $mixing_frequency = &param("train.mixing-frequency",0); +my $weight_dump_frequency = &param("train.weight-dump-frequency",0); +my $burn_in = &param("train.burn-in",0); +my $burn_in_input_file = &param("train.burn-in-input-file"); +my $burn_in_reference_files = &param("train.burn-in-reference-files"); #test configuration my ($test_input_file, $test_reference_file,$test_ini_file,$bleu_script,$use_moses); @@ -112,33 +113,54 @@ my $skip_test = &param("test.skip-test",0); my $skip_dev = &param("test.skip-dev",0); -# adjust test frequency when using batches > 1 -if ($batch > 1) { - $mixing_frequency = 1; -} - # check that number of jobs, dump frequency and number of input sentences are compatible # shard size = number of input sentences / number of jobs, ensure shard size >= dump frequency my $result = `wc -l $input_file`; my @result = split(/\s/, $result); my $inputSize = $result[0]; my 
$shardSize = $inputSize / $jobs; -if ($shardSize < $mixing_frequency) { - $mixing_frequency = int($shardSize); - if ($mixing_frequency == 0) { - $mixing_frequency = 1; - } +if ($mixing_frequency != 0) { + if ($shardSize < $mixing_frequency) { + $mixing_frequency = int($shardSize); + if ($mixing_frequency == 0) { + $mixing_frequency = 1; + } - print "Warning: mixing frequency must not be larger than shard size, setting mixing frequency to $mixing_frequency\n"; + print "Warning: mixing frequency must not be larger than shard size, setting mixing frequency to $mixing_frequency\n"; + } } -if ($shardSize < $weight_dump_frequency) { - $weight_dump_frequency = int($shardSize); - if ($weight_dump_frequency == 0) { - $weight_dump_frequency = 1; +if ($weight_dump_frequency != 0) { + if ($shardSize < $weight_dump_frequency) { + $weight_dump_frequency = int($shardSize); + if ($weight_dump_frequency == 0) { + $weight_dump_frequency = 1; + } + + print "Warning: weight dump frequency must not be larger than shard size, setting weight dump frequency to $weight_dump_frequency\n"; } +} - print "Warning: weight dump frequency must not be larger than shard size, setting weight dump frequency to $weight_dump_frequency\n"; +if ($mixing_frequency != 0) { + if ($mixing_frequency > ($shardSize/$batch)) { + $mixing_frequency = int($shardSize/$batch); + if ($mixing_frequency == 0) { + $mixing_frequency = 1; + } + + print "Warning: mixing frequency must not be larger than (shard size/batch size), setting mixing frequency to $mixing_frequency\n"; + } +} + +if ($weight_dump_frequency != 0) { + if ($weight_dump_frequency > ($shardSize/$batch)) { + $weight_dump_frequency = int($shardSize/$batch); + if ($weight_dump_frequency == 0) { + $weight_dump_frequency = 1; + } + + print "Warning: weight dump frequency must not be larger than (shard size/batch size), setting weight dump frequency to $weight_dump_frequency\n"; + } } #file names @@ -169,13 +191,30 @@ for my $ref (@refs) { print TRAIN "-r $ref 
"; } print TRAIN "\\\n"; +if ($burn_in) { + print TRAIN "--burn-in 1 \\\n"; + print TRAIN "--burn-in-input-file $burn_in_input_file \\\n"; + my @refs; + if (ref($burn_in_reference_files) eq 'ARRAY') { + @refs = @$burn_in_reference_files; + } else { + @refs = glob $burn_in_reference_files; + } + for my $ref (@refs) { + &check_exists("burn-in ref file", $ref); + print TRAIN "--burn-in-reference-files $ref "; + } + print TRAIN "\\\n"; +} #if ($weights_file) { # print TRAIN "-w $weights_file \\\n"; #} print TRAIN "-l $learner \\\n"; print TRAIN "--weight-dump-stem $weight_file_stem \\\n"; print TRAIN "--mixing-frequency $mixing_frequency \\\n"; -print TRAIN "--weight-dump-frequency $weight_dump_frequency \\\n"; +if ($weight_dump_frequency != -1) { + print TRAIN "--weight-dump-frequency $weight_dump_frequency \\\n"; +} print TRAIN "--epochs $epochs \\\n"; print TRAIN "-b $batch \\\n"; print TRAIN "--decoder-settings \"$decoder_settings\" \\\n"; @@ -213,26 +252,26 @@ while(1) { my($epoch, $epoch_slice); $train_iteration += 1; my $new_weight_file = "$working_dir/$weight_file_stem" . 
"_"; - my $totalAverageWeightFile; - if ($mixing_frequency == 1) { + if ($weight_dump_frequency == 0) { + print "No weights, no testing..\n"; + exit(0); + } + + if ($weight_dump_frequency == 1) { if ($train_iteration < 10) { $new_weight_file .= "0".$train_iteration; - $totalAverageWeightFile = $new_weight_file."_averageTotal"; } else { $new_weight_file .= $train_iteration; - $totalAverageWeightFile = $new_weight_file."_averageTotal"; } } else { - #my $epoch = 1 + int $train_iteration / $mixing_frequency; - $epoch = int $train_iteration / $mixing_frequency; - $epoch_slice = $train_iteration % $mixing_frequency; + #my $epoch = 1 + int $train_iteration / $weight_dump_frequency; + $epoch = int $train_iteration / $weight_dump_frequency; + $epoch_slice = $train_iteration % $weight_dump_frequency; if ($epoch < 10) { - $totalAverageWeightFile = $new_weight_file."0".$epoch."_averageTotal"; $new_weight_file .= "0".$epoch."_".$epoch_slice; } else { - $totalAverageWeightFile = $new_weight_file.$epoch."_averageTotal"; $new_weight_file .= $epoch."_".$epoch_slice; } } @@ -283,7 +322,7 @@ sub createTestScriptAndSubmit { my $output_file; my $output_error_file; my $bleu_file; - if ($mixing_frequency == 1) { + if ($weight_dump_frequency == 1) { if ($train_iteration < 10) { $output_file = $working_dir."/".$name."_0".$train_iteration.$suffix."_$testtype".".out"; $output_error_file = $working_dir."/".$name."_0".$train_iteration.$suffix."_$testtype".".err"; @@ -335,6 +374,8 @@ sub createTestScriptAndSubmit { if (! 
(open WEIGHTS, "$core_weight_file")) { die "Unable to open weights file $core_weight_file\n"; } + + my $readCoreWeights = 0; my $readExtraWeights = 0; my %extra_weights; while(<WEIGHTS>) { @@ -347,17 +388,23 @@ sub createTestScriptAndSubmit { } else { if ($name eq "WordPenalty") { $wordpenalty_weight = $value; + $readCoreWeights += 1; } elsif ($name =~ /^PhraseModel/) { push @phrasemodel_weights,$value; + $readCoreWeights += scalar @phrasemodel_weights; } elsif ($name =~ /^LM\:2/) { $lm2_weight = $value; + $readCoreWeights += 1; } elsif ($name =~ /^LM/) { $lm_weight = $value; + $readCoreWeights += 1; } elsif ($name eq "Distortion") { $distortion_weight = $value; + $readCoreWeights += 1; } elsif ($name =~ /^LexicalReordering/) { push @lexicalreordering_weights,$value; + $readCoreWeights += scalar @lexicalreordering_weights; } else { $extra_weights{$name} = $value; $readExtraWeights += 1; @@ -368,8 +415,11 @@ sub createTestScriptAndSubmit { print "Number of extra weights read: ".$readExtraWeights."\n"; - die "LM weight not defined" unless defined $lm_weight; - + if ($readCoreWeights == 0) { + print "No core weights defined.. skipping weight file\n"; + return; + } + # If there was a core weight file, then we have to load the weights # from the new weight file if ($core_weight_file ne $new_weight_file) { @@ -470,6 +520,12 @@ sub createTestScriptAndSubmit { print TEST "#\$ -o $test_out\n"; print TEST "#\$ -e $test_err\n"; print TEST "\n"; + if ($have_sge) { +# some eddie specific stuff + print TEST ". /etc/profile.d/modules.sh\n"; + print TEST "module load openmpi/ethernet/gcc/latest\n"; + print TEST "export LD_LIBRARY_PATH=/exports/informatics/inf_iccs_smt/shared/boost/lib:\$LD_LIBRARY_PATH\n"; + } print TEST "$test_exe $decoder_settings -i $input_file -f $new_ini_file "; if ($extra_weight_file) { print TEST "-weight-file $extra_weight_file ";