mirror of
https://github.com/marian-nmt/marian.git
synced 2024-10-04 02:28:15 +03:00
Merged PR 27051: Add an option for completely resetting validation metrics
Added `--valid-reset-all`, which works like `--valid-reset-stalled` but also resets the last best saved validation metrics; this is useful when the validation sets change for continued training. Added a new regression test: https://github.com/marian-nmt/marian-regression-tests/pull/89
This commit is contained in:
parent
b7205fc0b0
commit
ee50d4aaea
@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||||||
- Fused inplace-dropout in FFN layer in Transformer
|
- Fused inplace-dropout in FFN layer in Transformer
|
||||||
- `--force-decode` option for marian-decoder
|
- `--force-decode` option for marian-decoder
|
||||||
- `--output-sampling` now works with ensembles (requires proper normalization via e.g. `--weights 0.5 0.5`)
|
- `--output-sampling` now works with ensembles (requires proper normalization via e.g. `--weights 0.5 0.5`)
|
||||||
|
- `--valid-reset-all` option
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
- Make concat factors not break old vector implementation
|
- Make concat factors not break old vector implementation
|
||||||
|
@ -595,7 +595,10 @@ stages:
|
|||||||
|
|
||||||
# The following packages are already installed on Azure-hosted runners: build-essential openssl libssl-dev
|
# The following packages are already installed on Azure-hosted runners: build-essential openssl libssl-dev
|
||||||
# No need to install libprotobuf{17,10,9v5} on Ubuntu {20,18,16}.04 because it is installed together with libprotobuf-dev
|
# No need to install libprotobuf{17,10,9v5} on Ubuntu {20,18,16}.04 because it is installed together with libprotobuf-dev
|
||||||
- bash: sudo apt-get install -y libgoogle-perftools-dev libprotobuf-dev protobuf-compiler gcc-9 g++-9
|
# Installing libunwind-dev fixes a bug in 2204 (the libunwind-14 and libunwind-dev conflict)
|
||||||
|
- bash: |
|
||||||
|
sudo apt-get install -y libunwind-dev
|
||||||
|
sudo apt-get install -y libgoogle-perftools-dev libprotobuf-dev protobuf-compiler gcc-9 g++-9
|
||||||
displayName: Install packages
|
displayName: Install packages
|
||||||
|
|
||||||
# https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html
|
# https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html
|
||||||
|
@ -269,7 +269,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
|||||||
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
|
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
|
||||||
cli.add<int>("--transformer-dim-ffn",
|
cli.add<int>("--transformer-dim-ffn",
|
||||||
"Size of position-wise feed-forward network (transformer)",
|
"Size of position-wise feed-forward network (transformer)",
|
||||||
2048);
|
2048);
|
||||||
cli.add<int>("--transformer-decoder-dim-ffn",
|
cli.add<int>("--transformer-decoder-dim-ffn",
|
||||||
"Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.",
|
"Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.",
|
||||||
0);
|
0);
|
||||||
@ -591,7 +591,9 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
|||||||
"Multiple metrics can be specified",
|
"Multiple metrics can be specified",
|
||||||
{"cross-entropy"});
|
{"cross-entropy"});
|
||||||
cli.add<bool>("--valid-reset-stalled",
|
cli.add<bool>("--valid-reset-stalled",
|
||||||
"Reset all stalled validation metrics when the training is restarted");
|
"Reset stalled validation metrics when the training is restarted");
|
||||||
|
cli.add<bool>("--valid-reset-all",
|
||||||
|
"Reset all validation metrics when the training is restarted");
|
||||||
cli.add<size_t>("--early-stopping",
|
cli.add<size_t>("--early-stopping",
|
||||||
"Stop if the first validation metric does not improve for arg consecutive validation steps",
|
"Stop if the first validation metric does not improve for arg consecutive validation steps",
|
||||||
10);
|
10);
|
||||||
|
@ -494,12 +494,17 @@ public:
|
|||||||
state_->wordsDisp = 0;
|
state_->wordsDisp = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(options_->get<bool>("valid-reset-stalled")) {
|
if(options_->get<bool>("valid-reset-stalled") || options_->get<bool>("valid-reset-all")) {
|
||||||
state_->stalled = 0;
|
state_->stalled = 0;
|
||||||
state_->maxStalled = 0;
|
state_->maxStalled = 0;
|
||||||
for(const auto& validator : validators_) {
|
for(const auto& validator : validators_) {
|
||||||
if(state_->validators[validator->type()])
|
if(state_->validators[validator->type()]) {
|
||||||
|
// reset the number of stalled validations, e.g. when the validation set is the same
|
||||||
state_->validators[validator->type()]["stalled"] = 0;
|
state_->validators[validator->type()]["stalled"] = 0;
|
||||||
|
// reset last best results as well, e.g. when the validation set changes
|
||||||
|
if(options_->get<bool>("valid-reset-all"))
|
||||||
|
state_->validators[validator->type()]["last-best"] = validator->initScore();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -512,10 +517,10 @@ public:
|
|||||||
if(mpi_->isMainProcess())
|
if(mpi_->isMainProcess())
|
||||||
if(filesystem::exists(nameYaml))
|
if(filesystem::exists(nameYaml))
|
||||||
yamlStr = io::InputFileStream(nameYaml).readToString();
|
yamlStr = io::InputFileStream(nameYaml).readToString();
|
||||||
|
|
||||||
if(mpi_)
|
if(mpi_)
|
||||||
mpi_->bCast(yamlStr);
|
mpi_->bCast(yamlStr);
|
||||||
|
|
||||||
loadFromString(yamlStr);
|
loadFromString(yamlStr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user