mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merge branch 'master' into pmaster
This commit is contained in:
commit
84a20f65a1
@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Early stopping based on first, all, or any validation metrics via `--early-stopping-on`
|
||||
- Support for RMSNorm as drop-in replace for LayerNorm from `Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization`. Enabled in Transformer model via `--transformer-postprocess dar` instead of `dan`.
|
||||
- Extend suppression of unwanted output symbols, specifically "\n" from default vocabulary if generated by SentencePiece with byte-fallback. Deactivates with --allow-special
|
||||
- Allow for fine-grained CPU intrinsics overrides when BUILD_ARCH != native e.g. -DBUILD_ARCH=x86-64 -DCOMPILE_AVX512=off
|
||||
@ -41,6 +42,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
- Broken links to MNIST data sets
|
||||
|
||||
### Changed
|
||||
- Set REQUIRED_BIAS_ALIGNMENT = 16 in tensors/gpu/prod.cpp to avoid memory-misalignment on certain Ampere GPUs.
|
||||
- For BUILD_ARCH != native enable all intrinsics types by default, can be disabled like this: -DCOMPILE_AVX512=off
|
||||
- Moved FBGEMM pointer to commit c258054 for gcc 9.3+ fix
|
||||
- Change compile options a la -DCOMPILE_CUDA_SM35 to -DCOMPILE_KEPLER, -DCOMPILE_MAXWELL,
|
||||
|
@ -244,7 +244,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
||||
"Tie all embedding layers and output layer");
|
||||
cli.add<bool>("--output-omit-bias",
|
||||
"Do not use a bias vector in decoder output layer");
|
||||
|
||||
|
||||
// Transformer options
|
||||
cli.add<int>("--transformer-heads",
|
||||
"Number of heads in multi-head attention (transformer)",
|
||||
@ -529,13 +529,13 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
||||
"Window size over which the exponential average of the gradient norm is recorded (for logging and scaling). "
|
||||
"After this many updates about 90% of the mass of the exponential average comes from these updates",
|
||||
100);
|
||||
cli.add<std::vector<std::string>>("--dynamic-gradient-scaling",
|
||||
cli.add<std::vector<std::string>>("--dynamic-gradient-scaling",
|
||||
"Re-scale gradient to have average gradient norm if (log) gradient norm diverges from average by arg1 sigmas. "
|
||||
"If arg2 = \"log\" the statistics are recorded for the log of the gradient norm else use plain norm")
|
||||
->implicit_val("2.f log");
|
||||
cli.add<bool>("--check-gradient-nan",
|
||||
cli.add<bool>("--check-gradient-nan",
|
||||
"Skip parameter update in case of NaNs in gradient");
|
||||
cli.add<bool>("--normalize-gradient",
|
||||
cli.add<bool>("--normalize-gradient",
|
||||
"Normalize gradient by multiplying with no. devices / total labels (not recommended and to be removed in the future)");
|
||||
|
||||
cli.add<std::vector<std::string>>("--train-embedder-rank",
|
||||
@ -574,6 +574,10 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
||||
cli.add<size_t>("--early-stopping",
|
||||
"Stop if the first validation metric does not improve for arg consecutive validation steps",
|
||||
10);
|
||||
cli.add<std::string>("--early-stopping-on",
|
||||
"Decide if early stopping should take into account first, all, or any validation metrics"
|
||||
"Possible values: first, all, any",
|
||||
"first");
|
||||
|
||||
// decoding options
|
||||
cli.add<size_t>("--beam-size,-b",
|
||||
@ -586,7 +590,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
||||
"Maximum target length as source length times factor",
|
||||
3);
|
||||
cli.add<float>("--word-penalty",
|
||||
"Subtract (arg * translation length) from translation score ");
|
||||
"Subtract (arg * translation length) from translation score");
|
||||
cli.add<bool>("--allow-unk",
|
||||
"Allow unknown words to appear in output");
|
||||
cli.add<bool>("--n-best",
|
||||
|
@ -4,6 +4,8 @@
|
||||
#include "common/utils.h"
|
||||
#include "common/filesystem.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
namespace marian {
|
||||
|
||||
bool ConfigValidator::has(const std::string& key) const {
|
||||
@ -129,6 +131,11 @@ void ConfigValidator::validateOptionsTraining() const {
|
||||
&& !get<std::vector<std::string>>("valid-sets").empty(),
|
||||
errorMsg);
|
||||
|
||||
// check if --early-stopping-on has proper value
|
||||
std::set<std::string> supportedStops = {"first", "all", "any"};
|
||||
ABORT_IF(supportedStops.find(get<std::string>("early-stopping-on")) == supportedStops.end(),
|
||||
"Supported options for --early-stopping-on are: first, all, any");
|
||||
|
||||
// validations for learning rate decaying
|
||||
ABORT_IF(get<float>("lr-decay") > 1.f, "Learning rate decay factor greater than 1.0 is unusual");
|
||||
|
||||
@ -145,7 +152,7 @@ void ConfigValidator::validateOptionsTraining() const {
|
||||
// validate ULR options
|
||||
ABORT_IF((has("ulr") && get<bool>("ulr") && (get<std::string>("ulr-query-vectors") == ""
|
||||
|| get<std::string>("ulr-keys-vectors") == "")),
|
||||
"ULR enablign requires query and keys vectors specified with --ulr-query-vectors and "
|
||||
"ULR requires query and keys vectors specified with --ulr-query-vectors and "
|
||||
"--ulr-keys-vectors option");
|
||||
|
||||
// validate model quantization
|
||||
|
@ -22,7 +22,7 @@ namespace gpu {
|
||||
// It seems that the bias must be 8 byte aligned for the cublasLt epilogue to work. Therefore,
|
||||
// if the bias pointer is not 8 byte aligned, we do a normal matmul in cublasLt and invoke a
|
||||
// custom epilogue kernel.
|
||||
static constexpr int REQUIRED_BIAS_ALIGNMENT = 8;
|
||||
static constexpr int REQUIRED_BIAS_ALIGNMENT = 16; // @TODO: MJD: changed this to 16 to avoid alignment error on A100. Seems to work fine.
|
||||
|
||||
// Used to set preferences for cublasLt to filter out algos if matrices to not meet default 256 byte alignment
|
||||
int getAlignmentUpTo256(const void *ptr) {
|
||||
|
@ -28,7 +28,7 @@ private:
|
||||
// (regardless if it's the 1st or nth epoch and if it's a new or continued training),
|
||||
// which indicates the end of the training data stream from STDIN
|
||||
bool endOfStdin_{false}; // true at the end of the epoch if training from STDIN;
|
||||
|
||||
|
||||
// @TODO: figure out how to compute this with regard to updates as well, although maybe harder since no final value
|
||||
// determine scheduled LR decay factor (--lr-decay-inv-sqrt option)
|
||||
float getScheduledLRDecayFactor(const TrainingState& state) const {
|
||||
@ -133,7 +133,7 @@ public:
|
||||
Scheduler(Ptr<Options> options, Ptr<TrainingState> state, Ptr<IMPIWrapper> mpi = nullptr)
|
||||
: options_(options), state_(state), mpi_(mpi),
|
||||
gradientNormAvgWindow_(options_->get<size_t>("gradient-norm-average-window", 100)) {
|
||||
|
||||
|
||||
// parse logical-epoch parameters
|
||||
auto logicalEpochStr = options->get<std::vector<std::string>>("logical-epoch", {"1e", "0"});
|
||||
ABORT_IF(logicalEpochStr.empty(), "Logical epoch information is missing?");
|
||||
@ -174,7 +174,7 @@ public:
|
||||
size_t progress = state_->getProgressIn(mbWarmup.unit); // number of updates/labels processed
|
||||
auto progressRatio = (double)progress / (double)mbWarmup.n; // where are we relatively within target warm-up period
|
||||
// if unit is labels, then account for the fact that our increment itself is not constant
|
||||
#if 1 // this seems to hurt convergence quite a bit compared to when updates is used
|
||||
#if 1 // this seems to hurt convergence quite a bit compared to when updates is used
|
||||
if (mbWarmup.unit == SchedulingUnit::trgLabels)
|
||||
progressRatio = std::sqrt(progressRatio);
|
||||
#endif
|
||||
@ -207,7 +207,7 @@ public:
|
||||
if(saveAndExitRequested()) // via SIGTERM
|
||||
return false;
|
||||
|
||||
#if 1 // @TODO: to be removed once we deprecate after-epochs and after-batches
|
||||
#if 1 // @TODO: to be removed once we deprecate after-epochs and after-batches
|
||||
// stop if it reached the maximum number of epochs
|
||||
size_t stopAfterEpochs = options_->get<size_t>("after-epochs");
|
||||
if(stopAfterEpochs > 0 && calculateLogicalEpoch() > stopAfterEpochs)
|
||||
@ -231,10 +231,9 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// stop if the first validator did not improve for a given number of checks
|
||||
// stop if the first/all/any validators did not improve for a given number of checks
|
||||
size_t stopAfterStalled = options_->get<size_t>("early-stopping");
|
||||
if(stopAfterStalled > 0 && !validators_.empty()
|
||||
&& stalled() >= stopAfterStalled)
|
||||
if(stopAfterStalled > 0 && stalled() >= stopAfterStalled)
|
||||
return false;
|
||||
|
||||
// stop if data streaming from STDIN is stopped
|
||||
@ -297,12 +296,11 @@ public:
|
||||
|| (!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")) && !isFinal)) // not now
|
||||
return;
|
||||
|
||||
bool firstValidator = true;
|
||||
size_t stalledPrev = stalled();
|
||||
for(auto validator : validators_) {
|
||||
if(!validator)
|
||||
continue;
|
||||
|
||||
size_t stalledPrev = validator->stalled();
|
||||
float value = 0;
|
||||
if(!mpi_ || mpi_->isMainProcess()) {
|
||||
// We run validation only in the main process, but this is risky with MPI.
|
||||
@ -330,34 +328,60 @@ public:
|
||||
if(mpi_) {
|
||||
// collect and broadcast validation result to all processes and bring validator up-to-date
|
||||
mpi_->bCast(&value, 1, IMPIWrapper::getDataType(&value));
|
||||
|
||||
|
||||
// @TODO: add function to validator?
|
||||
mpi_->bCast(&validator->stalled(), 1, IMPIWrapper::getDataType(&validator->stalled()));
|
||||
mpi_->bCast(&validator->lastBest(), 1, IMPIWrapper::getDataType(&validator->lastBest()));
|
||||
}
|
||||
|
||||
if(firstValidator)
|
||||
state_->validBest = value;
|
||||
|
||||
state_->validators[validator->type()]["last-best"] = validator->lastBest();
|
||||
state_->validators[validator->type()]["stalled"] = validator->stalled();
|
||||
|
||||
// notify training observers if the first validator did not improve
|
||||
if(firstValidator && validator->stalled() > stalledPrev)
|
||||
state_->newStalled(validator->stalled());
|
||||
firstValidator = false;
|
||||
}
|
||||
|
||||
// notify training observers about stalled validation
|
||||
size_t stalledNew = stalled();
|
||||
if(stalledNew > stalledPrev)
|
||||
state_->newStalled(stalledNew);
|
||||
|
||||
state_->validated = true;
|
||||
}
|
||||
|
||||
// Returns the proper number of stalled validation w.r.t. early-stopping-on
|
||||
size_t stalled() {
|
||||
std::string stopOn = options_->get<std::string>("early-stopping-on");
|
||||
if(stopOn == "any")
|
||||
return stalledMax();
|
||||
if(stopOn == "all")
|
||||
return stalledMin();
|
||||
return stalled1st();
|
||||
}
|
||||
|
||||
// Returns the number of stalled validations for the first validator
|
||||
size_t stalled1st() {
|
||||
if(!validators_.empty())
|
||||
if(validators_[0])
|
||||
return validators_[0]->stalled();
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Returns the largest number of stalled validations across validators or 0 if there are no validators
|
||||
size_t stalledMax() {
|
||||
size_t max = 0;
|
||||
for(auto validator : validators_)
|
||||
if(validator && validator->stalled() > max)
|
||||
max = validator->stalled();
|
||||
return max;
|
||||
}
|
||||
|
||||
// Returns the lowest number of stalled validations across validators or 0 if there are no validators
|
||||
size_t stalledMin() {
|
||||
size_t min = std::numeric_limits<std::size_t>::max();
|
||||
for(auto validator : validators_)
|
||||
if(validator && validator->stalled() < min)
|
||||
min = validator->stalled();
|
||||
return min == std::numeric_limits<std::size_t>::max() ? 0 : min;
|
||||
}
|
||||
|
||||
void update(StaticLoss rationalLoss, Ptr<data::Batch> batch) {
|
||||
update(rationalLoss, /*numReadBatches=*/1, /*batchSize=*/batch->size(), /*batchLabels=*/batch->wordsTrg(), /*gradientNorm=*/0.f);
|
||||
}
|
||||
@ -397,8 +421,8 @@ public:
|
||||
|
||||
if(gradientNorm) {
|
||||
size_t range = std::min(gradientNormAvgWindow_, state_->batches);
|
||||
float alpha = 2.f / (float)(range + 1);
|
||||
|
||||
float alpha = 2.f / (float)(range + 1);
|
||||
|
||||
float delta = gradientNorm - state_->gradientNormAvg;
|
||||
state_->gradientNormAvg = state_->gradientNormAvg + alpha * delta;
|
||||
state_->gradientNormVar = (1.0f - alpha) * (state_->gradientNormVar + alpha * delta * delta);
|
||||
@ -440,7 +464,7 @@ public:
|
||||
formatLogicalEpoch(),
|
||||
state_->batches,
|
||||
utils::withCommas(state_->samplesEpoch),
|
||||
formatLoss(lossType, dispLabelCounts, batchLabels, state_),
|
||||
formatLoss(lossType, dispLabelCounts, batchLabels, state_),
|
||||
timer_.elapsed(),
|
||||
state_->wordsDisp / timer_.elapsed(),
|
||||
state_->gradientNormAvg);
|
||||
@ -627,7 +651,8 @@ public:
|
||||
|
||||
if(options_->get<bool>("lr-decay-repeat-warmup")) {
|
||||
LOG(info, "Restarting learning rate warmup");
|
||||
state.warmupStart.n = state.getProgressIn(SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
|
||||
state.warmupStart.n = state.getProgressIn(
|
||||
SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -43,8 +43,6 @@ public:
|
||||
size_t stalled{0};
|
||||
// The largest number of stalled validations so far
|
||||
size_t maxStalled{0};
|
||||
// Last best validation score
|
||||
float validBest{0.f};
|
||||
std::string validator;
|
||||
// List of validators
|
||||
YAML::Node validators;
|
||||
@ -217,7 +215,6 @@ public:
|
||||
|
||||
stalled = config["stalled"].as<size_t>();
|
||||
maxStalled = config["stalled-max"].as<size_t>();
|
||||
validBest = config["valid-best"].as<float>();
|
||||
validator = config["validator"].as<std::string>();
|
||||
validators = config["validators"];
|
||||
reset = config["reset"].as<bool>();
|
||||
@ -259,7 +256,6 @@ public:
|
||||
|
||||
config["stalled"] = stalled;
|
||||
config["stalled-max"] = maxStalled;
|
||||
config["valid-best"] = validBest;
|
||||
config["validator"] = validator;
|
||||
config["validators"] = validators;
|
||||
config["reset"] = reset;
|
||||
|
@ -447,7 +447,7 @@ SacreBleuValidator::SacreBleuValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Optio
|
||||
|
||||
ABORT_IF(computeChrF_ && useWordIds_, "Cannot compute ChrF on word ids"); // should not really happen, but let's check.
|
||||
|
||||
if(computeChrF_) // according to SacreBLEU implementation this is the default for ChrF,
|
||||
if(computeChrF_) // according to SacreBLEU implementation this is the default for ChrF,
|
||||
order_ = 6; // we compute stats over character ngrams up to length 6
|
||||
|
||||
// @TODO: remove, only used for saving?
|
||||
@ -613,12 +613,12 @@ void SacreBleuValidator::updateStats(std::vector<float>& stats,
|
||||
LOG_VALID_ONCE(info, "First sentence's tokens as scored:");
|
||||
LOG_VALID_ONCE(info, " Hyp: {}", utils::join(decode(cand, /*addEOS=*/false)));
|
||||
LOG_VALID_ONCE(info, " Ref: {}", utils::join(decode(ref, /*addEOS=*/false)));
|
||||
|
||||
|
||||
if(useWordIds_)
|
||||
updateStats(stats, cand, ref);
|
||||
else
|
||||
updateStats(stats, decode(cand, /*addEOS=*/false), decode(ref, /*addEOS=*/false));
|
||||
|
||||
|
||||
}
|
||||
|
||||
// Re-implementation of BLEU metric from SacreBLEU
|
||||
@ -627,7 +627,7 @@ float SacreBleuValidator::calcBLEU(const std::vector<float>& stats) {
|
||||
for(int i = 0; i < order_; ++i) {
|
||||
float commonNgrams = stats[statsPerOrder * i + 0];
|
||||
float hypothesesNgrams = stats[statsPerOrder * i + 1];
|
||||
|
||||
|
||||
if(commonNgrams == 0.f)
|
||||
return 0.f;
|
||||
logbleu += std::log(commonNgrams) - std::log(hypothesesNgrams);
|
||||
@ -653,7 +653,7 @@ float SacreBleuValidator::calcChrF(const std::vector<float>& stats) {
|
||||
float commonNgrams = stats[statsPerOrder * i + 0];
|
||||
float hypothesesNgrams = stats[statsPerOrder * i + 1];
|
||||
float referencesNgrams = stats[statsPerOrder * i + 2];
|
||||
|
||||
|
||||
if(hypothesesNgrams > 0 && referencesNgrams > 0) {
|
||||
avgPrecision += commonNgrams / hypothesesNgrams;
|
||||
avgRecall += commonNgrams / referencesNgrams;
|
||||
@ -666,10 +666,10 @@ float SacreBleuValidator::calcChrF(const std::vector<float>& stats) {
|
||||
|
||||
avgPrecision /= effectiveOrder;
|
||||
avgRecall /= effectiveOrder;
|
||||
|
||||
|
||||
if(avgPrecision + avgRecall == 0.f)
|
||||
return 0.f;
|
||||
|
||||
|
||||
auto betaSquare = beta * beta;
|
||||
auto score = (1.f + betaSquare) * (avgPrecision * avgRecall) / ((betaSquare * avgPrecision) + avgRecall);
|
||||
return score * 100.f; // we multiply by 100 which is usually not done for ChrF, but this makes it more comparable to BLEU
|
||||
|
@ -352,7 +352,7 @@ protected:
|
||||
private:
|
||||
const std::string metric_; // allowed values are: bleu, bleu-detok (same as bleu), bleu-segmented, chrf
|
||||
bool computeChrF_{ false }; // should we compute ChrF instead of BLEU (BLEU by default)?
|
||||
|
||||
|
||||
size_t order_{ 4 }; // 4-grams for BLEU by default
|
||||
static const size_t statsPerOrder = 3; // 0: common ngrams, 1: candidate ngrams, 2: reference ngrams
|
||||
bool useWordIds_{ false }; // compute BLEU score by matching numeric segment ids
|
||||
|
Loading…
Reference in New Issue
Block a user