Merged PR 18612: Early stopping on first, all, or any validation metrics

Adds `--early-stopping-on first|all|any`, allowing the user to decide whether early stopping should take into account only the first, all, or any of the validation metrics.

Feature request: https://github.com/marian-nmt/marian-dev/issues/850
Regression tests: https://github.com/marian-nmt/marian-regression-tests/pull/79
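
For illustration (not part of this commit), the new flag combines with the existing `--early-stopping` threshold. Assuming two validation metrics are configured, a training command could include, for example:

    --valid-metrics cross-entropy bleu --early-stopping 10 --early-stopping-on any

With `any`, training stops as soon as either metric has not improved for 10 consecutive validation runs; with `all`, it stops only once every metric has stalled that long; `first` (the default) preserves the previous behaviour of watching only the first metric.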
Roman Grundkiewicz 2021-04-26 11:51:43 +00:00
parent 3e51ff3872
commit 49e379bba5
8 changed files with 74 additions and 41 deletions

View File

@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
- Early stopping based on first, all, or any validation metrics via `--early-stopping-on`
- Support for RMSNorm as drop-in replacement for LayerNorm from `Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization`. Enabled in Transformer model via `--transformer-postprocess dar` instead of `dan`.
- Extend suppression of unwanted output symbols, specifically "\n" from default vocabulary if generated by SentencePiece with byte-fallback. Deactivates with --allow-special
- Allow for fine-grained CPU intrinsics overrides when BUILD_ARCH != native e.g. -DBUILD_ARCH=x86-64 -DCOMPILE_AVX512=off

@ -1 +1 @@
Subproject commit 7d612ca5e4b27a76f92584dad76d240e34f216d0
Subproject commit 1afd4eb1014ac451c6a3d6f9b5d34c322902e624

View File

@ -244,7 +244,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"Tie all embedding layers and output layer");
cli.add<bool>("--output-omit-bias",
"Do not use a bias vector in decoder output layer");
// Transformer options
cli.add<int>("--transformer-heads",
"Number of heads in multi-head attention (transformer)",
@ -529,13 +529,13 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"Window size over which the exponential average of the gradient norm is recorded (for logging and scaling). "
"After this many updates about 90% of the mass of the exponential average comes from these updates",
100);
cli.add<std::vector<std::string>>("--dynamic-gradient-scaling",
cli.add<std::vector<std::string>>("--dynamic-gradient-scaling",
"Re-scale gradient to have average gradient norm if (log) gradient norm diverges from average by arg1 sigmas. "
"If arg2 = \"log\" the statistics are recorded for the log of the gradient norm else use plain norm")
->implicit_val("2.f log");
cli.add<bool>("--check-gradient-nan",
cli.add<bool>("--check-gradient-nan",
"Skip parameter update in case of NaNs in gradient");
cli.add<bool>("--normalize-gradient",
cli.add<bool>("--normalize-gradient",
"Normalize gradient by multiplying with no. devices / total labels (not recommended and to be removed in the future)");
cli.add<std::vector<std::string>>("--train-embedder-rank",
@ -574,6 +574,10 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
cli.add<size_t>("--early-stopping",
"Stop if the first validation metric does not improve for arg consecutive validation steps",
10);
cli.add<std::string>("--early-stopping-on",
"Decide if early stopping should take into account first, all, or any validation metrics"
"Possible values: first, all, any",
"first");
// decoding options
cli.add<size_t>("--beam-size,-b",
@ -586,7 +590,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
"Maximum target length as source length times factor",
3);
cli.add<float>("--word-penalty",
"Subtract (arg * translation length) from translation score ");
"Subtract (arg * translation length) from translation score");
cli.add<bool>("--allow-unk",
"Allow unknown words to appear in output");
cli.add<bool>("--n-best",

View File

@ -4,6 +4,8 @@
#include "common/utils.h"
#include "common/filesystem.h"
#include <set>
namespace marian {
bool ConfigValidator::has(const std::string& key) const {
@ -129,6 +131,11 @@ void ConfigValidator::validateOptionsTraining() const {
&& !get<std::vector<std::string>>("valid-sets").empty(),
errorMsg);
// check if --early-stopping-on has proper value
std::set<std::string> supportedStops = {"first", "all", "any"};
ABORT_IF(supportedStops.find(get<std::string>("early-stopping-on")) == supportedStops.end(),
"Supported options for --early-stopping-on are: first, all, any");
// validations for learning rate decaying
ABORT_IF(get<float>("lr-decay") > 1.f, "Learning rate decay factor greater than 1.0 is unusual");
@ -145,7 +152,7 @@ void ConfigValidator::validateOptionsTraining() const {
// validate ULR options
ABORT_IF((has("ulr") && get<bool>("ulr") && (get<std::string>("ulr-query-vectors") == ""
|| get<std::string>("ulr-keys-vectors") == "")),
"ULR enablign requires query and keys vectors specified with --ulr-query-vectors and "
"ULR requires query and keys vectors specified with --ulr-query-vectors and "
"--ulr-keys-vectors option");
// validate model quantization

View File

@ -28,7 +28,7 @@ private:
// (regardless if it's the 1st or nth epoch and if it's a new or continued training),
// which indicates the end of the training data stream from STDIN
bool endOfStdin_{false}; // true at the end of the epoch if training from STDIN;
// @TODO: figure out how to compute this with regard to updates as well, although maybe harder since no final value
// determine scheduled LR decay factor (--lr-decay-inv-sqrt option)
float getScheduledLRDecayFactor(const TrainingState& state) const {
@ -133,7 +133,7 @@ public:
Scheduler(Ptr<Options> options, Ptr<TrainingState> state, Ptr<IMPIWrapper> mpi = nullptr)
: options_(options), state_(state), mpi_(mpi),
gradientNormAvgWindow_(options_->get<size_t>("gradient-norm-average-window", 100)) {
// parse logical-epoch parameters
auto logicalEpochStr = options->get<std::vector<std::string>>("logical-epoch", {"1e", "0"});
ABORT_IF(logicalEpochStr.empty(), "Logical epoch information is missing?");
@ -174,7 +174,7 @@ public:
size_t progress = state_->getProgressIn(mbWarmup.unit); // number of updates/labels processed
auto progressRatio = (double)progress / (double)mbWarmup.n; // where are we relatively within target warm-up period
// if unit is labels, then account for the fact that our increment itself is not constant
#if 1 // this seems to hurt convergence quite a bit compared to when updates is used
if (mbWarmup.unit == SchedulingUnit::trgLabels)
progressRatio = std::sqrt(progressRatio);
#endif
@ -207,7 +207,7 @@ public:
if(saveAndExitRequested()) // via SIGTERM
return false;
#if 1 // @TODO: to be removed once we deprecate after-epochs and after-batches
// stop if it reached the maximum number of epochs
size_t stopAfterEpochs = options_->get<size_t>("after-epochs");
if(stopAfterEpochs > 0 && calculateLogicalEpoch() > stopAfterEpochs)
@ -231,10 +231,9 @@ public:
}
}
// stop if the first validator did not improve for a given number of checks
// stop if the first/all/any validators did not improve for a given number of checks
size_t stopAfterStalled = options_->get<size_t>("early-stopping");
if(stopAfterStalled > 0 && !validators_.empty()
&& stalled() >= stopAfterStalled)
if(stopAfterStalled > 0 && stalled() >= stopAfterStalled)
return false;
// stop if data streaming from STDIN is stopped
@ -297,12 +296,11 @@ public:
|| (!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")) && !isFinal)) // not now
return;
bool firstValidator = true;
size_t stalledPrev = stalled();
for(auto validator : validators_) {
if(!validator)
continue;
size_t stalledPrev = validator->stalled();
float value = 0;
if(!mpi_ || mpi_->isMainProcess()) {
// We run validation only in the main process, but this is risky with MPI.
@ -330,34 +328,60 @@ public:
if(mpi_) {
// collect and broadcast validation result to all processes and bring validator up-to-date
mpi_->bCast(&value, 1, IMPIWrapper::getDataType(&value));
// @TODO: add function to validator?
mpi_->bCast(&validator->stalled(), 1, IMPIWrapper::getDataType(&validator->stalled()));
mpi_->bCast(&validator->lastBest(), 1, IMPIWrapper::getDataType(&validator->lastBest()));
}
if(firstValidator)
state_->validBest = value;
state_->validators[validator->type()]["last-best"] = validator->lastBest();
state_->validators[validator->type()]["stalled"] = validator->stalled();
// notify training observers if the first validator did not improve
if(firstValidator && validator->stalled() > stalledPrev)
state_->newStalled(validator->stalled());
firstValidator = false;
}
// notify training observers about stalled validation
size_t stalledNew = stalled();
if(stalledNew > stalledPrev)
state_->newStalled(stalledNew);
state_->validated = true;
}
// Returns the proper number of stalled validations w.r.t. early-stopping-on
size_t stalled() {
std::string stopOn = options_->get<std::string>("early-stopping-on");
if(stopOn == "any")
return stalledMax();
if(stopOn == "all")
return stalledMin();
return stalled1st();
}
// Returns the number of stalled validations for the first validator
size_t stalled1st() {
if(!validators_.empty())
if(validators_[0])
return validators_[0]->stalled();
return 0;
}
// Returns the largest number of stalled validations across validators or 0 if there are no validators
size_t stalledMax() {
size_t max = 0;
for(auto validator : validators_)
if(validator && validator->stalled() > max)
max = validator->stalled();
return max;
}
// Returns the lowest number of stalled validations across validators or 0 if there are no validators
size_t stalledMin() {
size_t min = std::numeric_limits<std::size_t>::max();
for(auto validator : validators_)
if(validator && validator->stalled() < min)
min = validator->stalled();
return min == std::numeric_limits<std::size_t>::max() ? 0 : min;
}
void update(StaticLoss rationalLoss, Ptr<data::Batch> batch) {
update(rationalLoss, /*numReadBatches=*/1, /*batchSize=*/batch->size(), /*batchLabels=*/batch->wordsTrg(), /*gradientNorm=*/0.f);
}
@ -397,8 +421,8 @@ public:
if(gradientNorm) {
size_t range = std::min(gradientNormAvgWindow_, state_->batches);
float alpha = 2.f / (float)(range + 1);
float delta = gradientNorm - state_->gradientNormAvg;
state_->gradientNormAvg = state_->gradientNormAvg + alpha * delta;
state_->gradientNormVar = (1.0f - alpha) * (state_->gradientNormVar + alpha * delta * delta);
@ -440,7 +464,7 @@ public:
formatLogicalEpoch(),
state_->batches,
utils::withCommas(state_->samplesEpoch),
formatLoss(lossType, dispLabelCounts, batchLabels, state_),
timer_.elapsed(),
state_->wordsDisp / timer_.elapsed(),
state_->gradientNormAvg);
@ -627,7 +651,8 @@ public:
if(options_->get<bool>("lr-decay-repeat-warmup")) {
LOG(info, "Restarting learning rate warmup");
state.warmupStart.n = state.getProgressIn(SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
state.warmupStart.n = state.getProgressIn(
SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
}
}
}

View File

@ -43,8 +43,6 @@ public:
size_t stalled{0};
// The largest number of stalled validations so far
size_t maxStalled{0};
// Last best validation score
float validBest{0.f};
std::string validator;
// List of validators
YAML::Node validators;
@ -217,7 +215,6 @@ public:
stalled = config["stalled"].as<size_t>();
maxStalled = config["stalled-max"].as<size_t>();
validBest = config["valid-best"].as<float>();
validator = config["validator"].as<std::string>();
validators = config["validators"];
reset = config["reset"].as<bool>();
@ -259,7 +256,6 @@ public:
config["stalled"] = stalled;
config["stalled-max"] = maxStalled;
config["valid-best"] = validBest;
config["validator"] = validator;
config["validators"] = validators;
config["reset"] = reset;

View File

@ -447,7 +447,7 @@ SacreBleuValidator::SacreBleuValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Optio
ABORT_IF(computeChrF_ && useWordIds_, "Cannot compute ChrF on word ids"); // should not really happen, but let's check.
if(computeChrF_) // according to SacreBLEU implementation this is the default for ChrF,
order_ = 6; // we compute stats over character ngrams up to length 6
// @TODO: remove, only used for saving?
@ -613,12 +613,12 @@ void SacreBleuValidator::updateStats(std::vector<float>& stats,
LOG_VALID_ONCE(info, "First sentence's tokens as scored:");
LOG_VALID_ONCE(info, " Hyp: {}", utils::join(decode(cand, /*addEOS=*/false)));
LOG_VALID_ONCE(info, " Ref: {}", utils::join(decode(ref, /*addEOS=*/false)));
if(useWordIds_)
updateStats(stats, cand, ref);
else
updateStats(stats, decode(cand, /*addEOS=*/false), decode(ref, /*addEOS=*/false));
}
// Re-implementation of BLEU metric from SacreBLEU
@ -627,7 +627,7 @@ float SacreBleuValidator::calcBLEU(const std::vector<float>& stats) {
for(int i = 0; i < order_; ++i) {
float commonNgrams = stats[statsPerOrder * i + 0];
float hypothesesNgrams = stats[statsPerOrder * i + 1];
if(commonNgrams == 0.f)
return 0.f;
logbleu += std::log(commonNgrams) - std::log(hypothesesNgrams);
@ -653,7 +653,7 @@ float SacreBleuValidator::calcChrF(const std::vector<float>& stats) {
float commonNgrams = stats[statsPerOrder * i + 0];
float hypothesesNgrams = stats[statsPerOrder * i + 1];
float referencesNgrams = stats[statsPerOrder * i + 2];
if(hypothesesNgrams > 0 && referencesNgrams > 0) {
avgPrecision += commonNgrams / hypothesesNgrams;
avgRecall += commonNgrams / referencesNgrams;
@ -666,10 +666,10 @@ float SacreBleuValidator::calcChrF(const std::vector<float>& stats) {
avgPrecision /= effectiveOrder;
avgRecall /= effectiveOrder;
if(avgPrecision + avgRecall == 0.f)
return 0.f;
auto betaSquare = beta * beta;
auto score = (1.f + betaSquare) * (avgPrecision * avgRecall) / ((betaSquare * avgPrecision) + avgRecall);
return score * 100.f; // we multiply by 100 which is usually not done for ChrF, but this makes it more comparable to BLEU

View File

@ -352,7 +352,7 @@ protected:
private:
const std::string metric_; // allowed values are: bleu, bleu-detok (same as bleu), bleu-segmented, chrf
bool computeChrF_{ false }; // should we compute ChrF instead of BLEU (BLEU by default)?
size_t order_{ 4 }; // 4-grams for BLEU by default
static const size_t statsPerOrder = 3; // 0: common ngrams, 1: candidate ngrams, 2: reference ngrams
bool useWordIds_{ false }; // compute BLEU score by matching numeric segment ids