update training script

git-svn-id: http://svn.statmt.org/repository/mira@3885 cc96ff50-19ce-11e0-b349-13d7f0bd23df
This commit is contained in:
ehasler 2011-04-30 16:17:37 +00:00 committed by Ondrej Bojar
parent 479ba8d160
commit 1f6f8b4abb

View File

@ -77,7 +77,6 @@ my $trainer_exe = &param_required("train.trainer");
#my $weights_file = &param_required("train.weights-file");
#&check_exists("weights file ", $weights_file);
#optional training parameters
my $epochs = &param("train.epochs", 2);
my $learner = &param("train.learner", "mira");
@ -87,9 +86,11 @@ my $continue_from_epoch = &param("train.continue-from-epoch", 0);
my $by_node = &param("train.by-node",0);
my $slots = &param("train.slots",8);
my $jobs = &param("train.jobs",8);
my $mixing_frequency = &param("train.mixing-frequency",1);
my $weight_dump_frequency = &param("train.weight-dump-frequency",1);
my $mixing_frequency = &param("train.mixing-frequency",0);
my $weight_dump_frequency = &param("train.weight-dump-frequency",0);
my $burn_in = &param("train.burn-in",0);
my $burn_in_input_file = &param("train.burn-in-input-file");
my $burn_in_reference_files = &param("train.burn-in-reference-files");
#test configuration
my ($test_input_file, $test_reference_file,$test_ini_file,$bleu_script,$use_moses);
@ -112,33 +113,54 @@ my $skip_test = &param("test.skip-test",0);
my $skip_dev = &param("test.skip-dev",0);
# adjust test frequency when using batches > 1
if ($batch > 1) {
$mixing_frequency = 1;
}
# check that number of jobs, dump frequency and number of input sentences are compatible
# shard size = number of input sentences / number of jobs, ensure shard size >= dump frequency
my $result = `wc -l $input_file`;
my @result = split(/\s/, $result);
my $inputSize = $result[0];
my $shardSize = $inputSize / $jobs;
if ($shardSize < $mixing_frequency) {
$mixing_frequency = int($shardSize);
if ($mixing_frequency == 0) {
$mixing_frequency = 1;
}
if ($mixing_frequency != 0) {
if ($shardSize < $mixing_frequency) {
$mixing_frequency = int($shardSize);
if ($mixing_frequency == 0) {
$mixing_frequency = 1;
}
print "Warning: mixing frequency must not be larger than shard size, setting mixing frequency to $mixing_frequency\n";
print "Warning: mixing frequency must not be larger than shard size, setting mixing frequency to $mixing_frequency\n";
}
}
if ($shardSize < $weight_dump_frequency) {
$weight_dump_frequency = int($shardSize);
if ($weight_dump_frequency == 0) {
$weight_dump_frequency = 1;
if ($weight_dump_frequency != 0) {
if ($shardSize < $weight_dump_frequency) {
$weight_dump_frequency = int($shardSize);
if ($weight_dump_frequency == 0) {
$weight_dump_frequency = 1;
}
print "Warning: weight dump frequency must not be larger than shard size, setting weight dump frequency to $weight_dump_frequency\n";
}
}
print "Warning: weight dump frequency must not be larger than shard size, setting weight dump frequency to $weight_dump_frequency\n";
if ($mixing_frequency != 0) {
if ($mixing_frequency > ($shardSize/$batch)) {
$mixing_frequency = int($shardSize/$batch);
if ($mixing_frequency == 0) {
$mixing_frequency = 1;
}
print "Warning: mixing frequency must not be larger than (shard size/batch size), setting mixing frequency to $mixing_frequency\n";
}
}
if ($weight_dump_frequency != 0) {
if ($weight_dump_frequency > ($shardSize/$batch)) {
$weight_dump_frequency = int($shardSize/$batch);
if ($weight_dump_frequency == 0) {
$weight_dump_frequency = 1;
}
print "Warning: weight dump frequency must not be larger than (shard size/batch size), setting weight dump frequency to $weight_dump_frequency\n";
}
}
#file names
@ -169,13 +191,30 @@ for my $ref (@refs) {
print TRAIN "-r $ref ";
}
print TRAIN "\\\n";
if ($burn_in) {
print TRAIN "--burn-in 1 \\\n";
print TRAIN "--burn-in-input-file $burn_in_input_file \\\n";
my @refs;
if (ref($burn_in_reference_files) eq 'ARRAY') {
@refs = @$burn_in_reference_files;
} else {
@refs = glob $burn_in_reference_files;
}
for my $ref (@refs) {
&check_exists("burn-in ref file", $ref);
print TRAIN "--burn-in-reference-files $ref ";
}
print TRAIN "\\\n";
}
#if ($weights_file) {
# print TRAIN "-w $weights_file \\\n";
#}
print TRAIN "-l $learner \\\n";
print TRAIN "--weight-dump-stem $weight_file_stem \\\n";
print TRAIN "--mixing-frequency $mixing_frequency \\\n";
print TRAIN "--weight-dump-frequency $weight_dump_frequency \\\n";
if ($weight_dump_frequency != -1) {
print TRAIN "--weight-dump-frequency $weight_dump_frequency \\\n";
}
print TRAIN "--epochs $epochs \\\n";
print TRAIN "-b $batch \\\n";
print TRAIN "--decoder-settings \"$decoder_settings\" \\\n";
@ -213,26 +252,26 @@ while(1) {
my($epoch, $epoch_slice);
$train_iteration += 1;
my $new_weight_file = "$working_dir/$weight_file_stem" . "_";
my $totalAverageWeightFile;
if ($mixing_frequency == 1) {
if ($weight_dump_frequency == 0) {
print "No weights, no testing..\n";
exit(0);
}
if ($weight_dump_frequency == 1) {
if ($train_iteration < 10) {
$new_weight_file .= "0".$train_iteration;
$totalAverageWeightFile = $new_weight_file."_averageTotal";
}
else {
$new_weight_file .= $train_iteration;
$totalAverageWeightFile = $new_weight_file."_averageTotal";
}
} else {
#my $epoch = 1 + int $train_iteration / $mixing_frequency;
$epoch = int $train_iteration / $mixing_frequency;
$epoch_slice = $train_iteration % $mixing_frequency;
#my $epoch = 1 + int $train_iteration / $weight_dump_frequency;
$epoch = int $train_iteration / $weight_dump_frequency;
$epoch_slice = $train_iteration % $weight_dump_frequency;
if ($epoch < 10) {
$totalAverageWeightFile = $new_weight_file."0".$epoch."_averageTotal";
$new_weight_file .= "0".$epoch."_".$epoch_slice;
}
else {
$totalAverageWeightFile = $new_weight_file.$epoch."_averageTotal";
$new_weight_file .= $epoch."_".$epoch_slice;
}
}
@ -283,7 +322,7 @@ sub createTestScriptAndSubmit {
my $output_file;
my $output_error_file;
my $bleu_file;
if ($mixing_frequency == 1) {
if ($weight_dump_frequency == 1) {
if ($train_iteration < 10) {
$output_file = $working_dir."/".$name."_0".$train_iteration.$suffix."_$testtype".".out";
$output_error_file = $working_dir."/".$name."_0".$train_iteration.$suffix."_$testtype".".err";
@ -335,6 +374,8 @@ sub createTestScriptAndSubmit {
if (! (open WEIGHTS, "$core_weight_file")) {
die "Unable to open weights file $core_weight_file\n";
}
my $readCoreWeights = 0;
my $readExtraWeights = 0;
my %extra_weights;
while(<WEIGHTS>) {
@ -347,17 +388,23 @@ sub createTestScriptAndSubmit {
} else {
if ($name eq "WordPenalty") {
$wordpenalty_weight = $value;
$readCoreWeights += 1;
} elsif ($name =~ /^PhraseModel/) {
push @phrasemodel_weights,$value;
$readCoreWeights += scalar @phrasemodel_weights;
} elsif ($name =~ /^LM\:2/) {
$lm2_weight = $value;
$readCoreWeights += 1;
}
elsif ($name =~ /^LM/) {
$lm_weight = $value;
$readCoreWeights += 1;
} elsif ($name eq "Distortion") {
$distortion_weight = $value;
$readCoreWeights += 1;
} elsif ($name =~ /^LexicalReordering/) {
push @lexicalreordering_weights,$value;
$readCoreWeights += scalar @lexicalreordering_weights;
} else {
$extra_weights{$name} = $value;
$readExtraWeights += 1;
@ -368,8 +415,11 @@ sub createTestScriptAndSubmit {
print "Number of extra weights read: ".$readExtraWeights."\n";
die "LM weight not defined" unless defined $lm_weight;
if ($readCoreWeights == 0) {
print "No core weights defined.. skipping weight file\n";
return;
}
# If there was a core weight file, then we have to load the weights
# from the new weight file
if ($core_weight_file ne $new_weight_file) {
@ -470,6 +520,12 @@ sub createTestScriptAndSubmit {
print TEST "#\$ -o $test_out\n";
print TEST "#\$ -e $test_err\n";
print TEST "\n";
if ($have_sge) {
# some eddie specific stuff
print TEST ". /etc/profile.d/modules.sh\n";
print TEST "module load openmpi/ethernet/gcc/latest\n";
print TEST "export LD_LIBRARY_PATH=/exports/informatics/inf_iccs_smt/shared/boost/lib:\$LD_LIBRARY_PATH\n";
}
print TEST "$test_exe $decoder_settings -i $input_file -f $new_ini_file ";
if ($extra_weight_file) {
print TEST "-weight-file $extra_weight_file ";