Merged PR 15896: Add --after N option to supersede --after-batches and --after-epochs

Replace `--after-batches N` and `--after-epochs N` with `--after Nu/Ne` which allows to specify updates, epochs, target labels with units, e.g.:
* `--after 30Gt` or `--after 50ku` or `--after 10e`
* Can also combine multiple criteria: `--after 30Gt,50ku,10e` and will stop when whichever hits first

Changes default `cost-type` from `ce-mean` to `ce-sum` and turns `display-label-counts` on by default.
This commit is contained in:
Martin Junczys-Dowmunt 2020-10-29 20:16:19 +00:00
parent ae866af035
commit 160b36cec8
6 changed files with 35 additions and 9 deletions

View File

@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
- Add --after option which is meant to replace --after-batches and --after-epochs and can take label based criteria
- Add --transformer-postprocess-top option to enable correctly normalized prenorm behavior
- Add --task transformer-base-prenorm and --task transformer-big-prenorm
- Turing and Ampere GPU optimisation support, if the CUDA version supports it.

@ -1 +1 @@
Subproject commit cdad78089484d7817d91c803d6fc7049328e20db
Subproject commit 75977846abfccd29941e4bfd3c615a111599f7f4

View File

@ -10,9 +10,12 @@ namespace marian {
namespace cli {
// clang-format off
const std::unordered_set<std::string> DEPRECIATED_OPTIONS = {
const std::unordered_set<std::string> DEPRECATED_OPTIONS = {
"version",
"special-vocab"
"special-vocab",
// @TODO: uncomment once we actually deprecate them.
// "after-batches",
// "after-epochs"
};
// clang-format on
@ -177,7 +180,7 @@ void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority prio
if(cmdOptions.count(key))
continue;
// Skip options that might exist in config files generated by older versions of Marian
if(DEPRECIATED_OPTIONS.count(key))
if(DEPRECATED_OPTIONS.count(key))
continue;
// Check if an incoming option has been defined in CLI

View File

@ -347,7 +347,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
auto previous_group = cli.switchGroup("Training options");
// clang-format off
cli.add<std::string>("--cost-type", // @TODO: rename to loss-type
"Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-mean");
"Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-sum");
cli.add<std::string>("--multi-loss-type",
"How to accumulate multi-objective losses: sum, scaled, mean", "sum");
cli.add<bool>("--unlikelihood-loss",
@ -375,17 +375,24 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
10000000);
#endif
// scheduling options
// @TODO: these should be re-defined as aliases for `--after` but the current frame work matches on value, so not doable.
cli.add<size_t>("--after-epochs,-e",
"Finish after this many epochs, 0 is infinity");
"Finish after this many epochs, 0 is infinity (deprecated, '--after-epochs N' corresponds to '--after Ne')"); // @TODO: replace with alias
cli.add<size_t>("--after-batches",
"Finish after this many batch updates, 0 is infinity");
"Finish after this many batch updates, 0 is infinity (deprecated, '--after-batches N' corresponds to '--after Nu')"); // @TODO: replace with alias
cli.add<std::string>("--after,-a",
"Finish after this many chosen training units, 0 is infinity (e.g. 100e = 100 epochs, 10Gt = 10 billion target labels, 100Ku = 100,000 updates",
"0e");
cli.add<std::string/*SchedulerPeriod*/>("--disp-freq",
"Display information every arg updates (append 't' for every arg target labels)",
"1000u");
cli.add<size_t>("--disp-first",
"Display information for the first arg updates");
cli.add<bool>("--disp-label-counts",
"Display label counts when logging loss progress");
"Display label counts when logging loss progress",
true);
// cli.add<int>("--disp-label-index",
// "Display label counts based on i-th input stream (-1 is last)", -1);
cli.add<std::string/*SchedulerPeriod*/>("--save-freq",
@ -901,7 +908,7 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
cli::mode ConfigParser::getMode() const { return mode_; }
Ptr<Options> ConfigParser::parseOptions(int argc, char** argv, bool doValidate){
Ptr<Options> ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
cmdLine_ = escapeCmdLine(argc,argv);
// parse command-line options and fill wrapped YAML config

View File

@ -415,6 +415,7 @@ double parseNumber(std::string param) {
if(!param.empty() && param.back() >= 'A') {
switch(param.back()) {
case 'k': factor = 1.e3; break;
case 'K': factor = 1.e3; break; // not technically correct but often used for k
case 'M': factor = 1.e6; break;
case 'G': factor = 1.e9; break;
case 'T': factor = 1.e12; break;

View File

@ -158,6 +158,7 @@ public:
if(saveAndExitRequested()) // via SIGTERM
return false;
#if 1 // @TODO: to be removed once we deprecate after-epochs and after-batches
// stop if it reached the maximum number of epochs
size_t stopAfterEpochs = options_->get<size_t>("after-epochs");
if(stopAfterEpochs > 0 && state_->epochs > stopAfterEpochs)
@ -167,6 +168,19 @@ public:
size_t stopAfterBatches = options_->get<size_t>("after-batches");
if(stopAfterBatches > 0 && state_->batches >= stopAfterBatches)
return false;
#endif
// get list of stopping criteria e.g. "10e,300Ku,20Gt" (10 epochs, 300,000 updates, 20 billion target labels)
// and stop for whatever criterion hits first.
std::vector<std::string> stoppingCriteria = utils::split(options_->get<std::string>("after"), ",");
for(auto stoppingCriterionString : stoppingCriteria) {
SchedulingParameter stoppingCriterion = SchedulingParameter::parse(stoppingCriterionString);
if(stoppingCriterion.n > 0) { // is any stopping criterion defined?
if(stoppingCriterion.unit == SchedulingUnit::epochs && state_->epochs > stoppingCriterion.n) return false;
if(stoppingCriterion.unit == SchedulingUnit::updates && state_->batches >= stoppingCriterion.n) return false;
if(stoppingCriterion.unit == SchedulingUnit::trgLabels && state_->labelsTotal >= stoppingCriterion.n) return false;
}
}
// stop if the first validator did not improve for a given number of checks
size_t stopAfterStalled = options_->get<size_t>("early-stopping");