mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 15896: Add --after N option to supersede --after-batches and --after-epochs
Replace `--after-batches N` and `--after-epochs N` with `--after Nu/Ne` which allows to specify updates, epochs, target labels with units, e.g.: * `--after 30Gt` or `--after 50ku` or `--after 10e` * Can also combine multiple criteria: `--after 30Gt,50ku,10e` and will stop when whichever hits first Changes default `cost-type` from `ce-mean` to `ce-sum` and turns `display-label-counts` on by default.
This commit is contained in:
parent
ae866af035
commit
160b36cec8
@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Add --after option which is meant to replace --after-batches and --after-epochs and can take label based criteria
|
||||
- Add --transformer-postprocess-top option to enable correctly normalized prenorm behavior
|
||||
- Add --task transformer-base-prenorm and --task transformer-big-prenorm
|
||||
- Turing and Ampere GPU optimisation support, if the CUDA version supports it.
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit cdad78089484d7817d91c803d6fc7049328e20db
|
||||
Subproject commit 75977846abfccd29941e4bfd3c615a111599f7f4
|
@ -10,9 +10,12 @@ namespace marian {
|
||||
namespace cli {
|
||||
|
||||
// clang-format off
|
||||
const std::unordered_set<std::string> DEPRECIATED_OPTIONS = {
|
||||
const std::unordered_set<std::string> DEPRECATED_OPTIONS = {
|
||||
"version",
|
||||
"special-vocab"
|
||||
"special-vocab",
|
||||
// @TODO: uncomment once we actually deprecate them.
|
||||
// "after-batches",
|
||||
// "after-epochs"
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
@ -177,7 +180,7 @@ void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority prio
|
||||
if(cmdOptions.count(key))
|
||||
continue;
|
||||
// Skip options that might exist in config files generated by older versions of Marian
|
||||
if(DEPRECIATED_OPTIONS.count(key))
|
||||
if(DEPRECATED_OPTIONS.count(key))
|
||||
continue;
|
||||
|
||||
// Check if an incoming option has been defined in CLI
|
||||
|
@ -347,7 +347,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
||||
auto previous_group = cli.switchGroup("Training options");
|
||||
// clang-format off
|
||||
cli.add<std::string>("--cost-type", // @TODO: rename to loss-type
|
||||
"Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-mean");
|
||||
"Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-sum");
|
||||
cli.add<std::string>("--multi-loss-type",
|
||||
"How to accumulate multi-objective losses: sum, scaled, mean", "sum");
|
||||
cli.add<bool>("--unlikelihood-loss",
|
||||
@ -375,17 +375,24 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
||||
10000000);
|
||||
#endif
|
||||
// scheduling options
|
||||
|
||||
// @TODO: these should be re-defined as aliases for `--after` but the current frame work matches on value, so not doable.
|
||||
cli.add<size_t>("--after-epochs,-e",
|
||||
"Finish after this many epochs, 0 is infinity");
|
||||
"Finish after this many epochs, 0 is infinity (deprecated, '--after-epochs N' corresponds to '--after Ne')"); // @TODO: replace with alias
|
||||
cli.add<size_t>("--after-batches",
|
||||
"Finish after this many batch updates, 0 is infinity");
|
||||
"Finish after this many batch updates, 0 is infinity (deprecated, '--after-batches N' corresponds to '--after Nu')"); // @TODO: replace with alias
|
||||
|
||||
cli.add<std::string>("--after,-a",
|
||||
"Finish after this many chosen training units, 0 is infinity (e.g. 100e = 100 epochs, 10Gt = 10 billion target labels, 100Ku = 100,000 updates",
|
||||
"0e");
|
||||
cli.add<std::string/*SchedulerPeriod*/>("--disp-freq",
|
||||
"Display information every arg updates (append 't' for every arg target labels)",
|
||||
"1000u");
|
||||
cli.add<size_t>("--disp-first",
|
||||
"Display information for the first arg updates");
|
||||
cli.add<bool>("--disp-label-counts",
|
||||
"Display label counts when logging loss progress");
|
||||
"Display label counts when logging loss progress",
|
||||
true);
|
||||
// cli.add<int>("--disp-label-index",
|
||||
// "Display label counts based on i-th input stream (-1 is last)", -1);
|
||||
cli.add<std::string/*SchedulerPeriod*/>("--save-freq",
|
||||
@ -901,7 +908,7 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
|
||||
|
||||
cli::mode ConfigParser::getMode() const { return mode_; }
|
||||
|
||||
Ptr<Options> ConfigParser::parseOptions(int argc, char** argv, bool doValidate){
|
||||
Ptr<Options> ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
||||
cmdLine_ = escapeCmdLine(argc,argv);
|
||||
|
||||
// parse command-line options and fill wrapped YAML config
|
||||
|
@ -415,6 +415,7 @@ double parseNumber(std::string param) {
|
||||
if(!param.empty() && param.back() >= 'A') {
|
||||
switch(param.back()) {
|
||||
case 'k': factor = 1.e3; break;
|
||||
case 'K': factor = 1.e3; break; // not technically correct but often used for k
|
||||
case 'M': factor = 1.e6; break;
|
||||
case 'G': factor = 1.e9; break;
|
||||
case 'T': factor = 1.e12; break;
|
||||
|
@ -158,6 +158,7 @@ public:
|
||||
if(saveAndExitRequested()) // via SIGTERM
|
||||
return false;
|
||||
|
||||
#if 1 // @TODO: to be removed once we deprecate after-epochs and after-batches
|
||||
// stop if it reached the maximum number of epochs
|
||||
size_t stopAfterEpochs = options_->get<size_t>("after-epochs");
|
||||
if(stopAfterEpochs > 0 && state_->epochs > stopAfterEpochs)
|
||||
@ -167,6 +168,19 @@ public:
|
||||
size_t stopAfterBatches = options_->get<size_t>("after-batches");
|
||||
if(stopAfterBatches > 0 && state_->batches >= stopAfterBatches)
|
||||
return false;
|
||||
#endif
|
||||
|
||||
// get list of stopping criteria e.g. "10e,300Ku,20Gt" (10 epochs, 300,000 updates, 20 billion target labels)
|
||||
// and stop for whatever criterion hits first.
|
||||
std::vector<std::string> stoppingCriteria = utils::split(options_->get<std::string>("after"), ",");
|
||||
for(auto stoppingCriterionString : stoppingCriteria) {
|
||||
SchedulingParameter stoppingCriterion = SchedulingParameter::parse(stoppingCriterionString);
|
||||
if(stoppingCriterion.n > 0) { // is any stopping criterion defined?
|
||||
if(stoppingCriterion.unit == SchedulingUnit::epochs && state_->epochs > stoppingCriterion.n) return false;
|
||||
if(stoppingCriterion.unit == SchedulingUnit::updates && state_->batches >= stoppingCriterion.n) return false;
|
||||
if(stoppingCriterion.unit == SchedulingUnit::trgLabels && state_->labelsTotal >= stoppingCriterion.n) return false;
|
||||
}
|
||||
}
|
||||
|
||||
// stop if the first validator did not improve for a given number of checks
|
||||
size_t stopAfterStalled = options_->get<size_t>("early-stopping");
|
||||
|
Loading…
Reference in New Issue
Block a user