Merged PR 27051: Add an option for completely resetting validation metrics

Added `--valid-reset-all` that works as `--valid-reset-stalled` but it also resets last best saved validation metrics, which is useful for when the validation sets change for continued training.

Added new regression test: https://github.com/marian-nmt/marian-regression-tests/pull/89
This commit is contained in:
Roman Grundkiewicz 2022-12-20 17:56:10 +00:00
parent b7205fc0b0
commit ee50d4aaea
5 changed files with 19 additions and 8 deletions

View File

@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Fused inplace-dropout in FFN layer in Transformer
- `--force-decode` option for marian-decoder
- `--output-sampling` now works with ensembles (requires proper normalization via e.g `--weights 0.5 0.5`)
- `--valid-reset-all` option
### Fixed
- Make concat factors not break old vector implementation

View File

@ -1 +1 @@
v1.11.14
v1.11.15

View File

@ -595,7 +595,10 @@ stages:
# The following packages are already installed on Azure-hosted runners: build-essential openssl libssl-dev
# No need to install libprotobuf{17,10,9v5} on Ubuntu {20,18,16}.04 because it is installed together with libprotobuf-dev
- bash: sudo apt-get install -y libgoogle-perftools-dev libprotobuf-dev protobuf-compiler gcc-9 g++-9
# Installing libunwind-dev fixes a bug in 2204 (the libunwind-14 and libunwind-dev conflict)
- bash: |
sudo apt-get install -y libunwind-dev
sudo apt-get install -y libgoogle-perftools-dev libprotobuf-dev protobuf-compiler gcc-9 g++-9
displayName: Install packages
# https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html

View File

@ -269,7 +269,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
cli.add<int>("--transformer-dim-ffn",
"Size of position-wise feed-forward network (transformer)",
2048);
2048);
cli.add<int>("--transformer-decoder-dim-ffn",
"Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.",
0);
@ -591,7 +591,9 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
"Multiple metrics can be specified",
{"cross-entropy"});
cli.add<bool>("--valid-reset-stalled",
"Reset all stalled validation metrics when the training is restarted");
"Reset stalled validation metrics when the training is restarted");
cli.add<bool>("--valid-reset-all",
"Reset all validation metrics when the training is restarted");
cli.add<size_t>("--early-stopping",
"Stop if the first validation metric does not improve for arg consecutive validation steps",
10);

View File

@ -494,12 +494,17 @@ public:
state_->wordsDisp = 0;
}
if(options_->get<bool>("valid-reset-stalled")) {
if(options_->get<bool>("valid-reset-stalled") || options_->get<bool>("valid-reset-all")) {
state_->stalled = 0;
state_->maxStalled = 0;
for(const auto& validator : validators_) {
if(state_->validators[validator->type()])
if(state_->validators[validator->type()]) {
// reset the number of stalled validations, e.g. when the validation set is the same
state_->validators[validator->type()]["stalled"] = 0;
// reset last best results as well, e.g. when the validation set changes
if(options_->get<bool>("valid-reset-all"))
state_->validators[validator->type()]["last-best"] = validator->initScore();
}
}
}
@ -512,10 +517,10 @@ public:
if(mpi_->isMainProcess())
if(filesystem::exists(nameYaml))
yamlStr = io::InputFileStream(nameYaml).readToString();
if(mpi_)
mpi_->bCast(yamlStr);
loadFromString(yamlStr);
}