diff --git a/mira/Main.cpp b/mira/Main.cpp index ddcab90ad..3a10e79f0 100644 --- a/mira/Main.cpp +++ b/mira/Main.cpp @@ -122,6 +122,7 @@ int main(int argc, char** argv) { string decode_filename; size_t update_scheme; bool separateUpdates, batchEqualsShard; + bool mixByAveraging; po::options_description desc("Allowed options"); desc.add_options() ("slack", po::value(&slack)->default_value(0.01), "Use slack in optimiser") @@ -172,6 +173,7 @@ int main(int argc, char** argv) { ("min-weight-change", po::value(&min_weight_change)->default_value(0.01), "Set minimum weight change for stopping criterion") ("mira-learning-rate", po::value(&mira_learning_rate)->default_value(1), "Learning rate for MIRA (fixed or flexible)") ("mixing-frequency", po::value(&mixingFrequency)->default_value(5), "How often per epoch to mix weights, when using mpi") + ("mix-by-averaging", po::value(&mixByAveraging)->default_value(true), "Average weights by the number of processes") ("model-hope-fear", po::value(&model_hope_fear)->default_value(false), "Use model, hope and fear translations for optimisation") ("nbest,n", po::value(&n)->default_value(1), "Number of translations in n-best list") ("normalise", po::value(&normaliseWeights)->default_value(false), "Whether to normalise the updated weights before passing them to the decoder") @@ -897,7 +899,8 @@ int main(int argc, char** argv) { mpi::reduce(world, mosesWeights, mixedWeights, SCCPlus(), 0); if (rank == 0) { // divide by number of processes - mixedWeights.DivideEquals(size); + if (mixByAveraging) + mixedWeights.DivideEquals(size); // normalise weights after averaging if (normaliseWeights) { @@ -941,7 +944,8 @@ int main(int argc, char** argv) { #endif if (rank == 0 && !weightDumpStem.empty()) { // divide by number of processes - mixedAverageWeights.DivideEquals(size); + if (mixByAveraging) + mixedAverageWeights.DivideEquals(size); // normalise weights after averaging if (normaliseWeights) {