Cherry picked cleaning/refactoring patches (#905)

Cherry-picked updates from pull request #457

Co-authored-by: Mateusz Chudyk <mateuszchudyk@gmail.com>
This commit is contained in:
Roman Grundkiewicz 2022-01-28 14:16:41 +00:00 committed by GitHub
parent 71b5454b9e
commit 07c39c7d76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 113 additions and 108 deletions

View File

@ -113,10 +113,10 @@ std::string CLIWrapper::switchGroup(std::string name) {
return name; return name;
} }
void CLIWrapper::parse(int argc, char **argv) { void CLIWrapper::parse(int argc, char** argv) {
try { try {
app_->parse(argc, argv); app_->parse(argc, argv);
} catch(const CLI::ParseError &e) { } catch(const CLI::ParseError& e) {
exit(app_->exit(e)); exit(app_->exit(e));
} }
@ -182,6 +182,13 @@ void CLIWrapper::parseAliases() {
} }
} }
std::string CLIWrapper::keyName(const std::string& args) const {
// re-use existing functions from CLI11 to keep option names consistent
return std::get<1>(
CLI::detail::get_names(CLI::detail::split_names(args))) // get long names only
.front(); // get first long name
}
void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg) { void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg) {
auto cmdOptions = getParsedOptionNames(); auto cmdOptions = getParsedOptionNames();
// Keep track of unrecognized options from the provided config // Keep track of unrecognized options from the provided config
@ -276,7 +283,7 @@ std::vector<std::string> CLIWrapper::getOrderedOptionNames() const {
for(auto const &it : options_) for(auto const &it : options_)
keys.push_back(it.first); keys.push_back(it.first);
// sort option names by creation index // sort option names by creation index
sort(keys.begin(), keys.end(), [this](const std::string &a, const std::string &b) { sort(keys.begin(), keys.end(), [this](const std::string& a, const std::string& b) {
return options_.at(a).idx < options_.at(b).idx; return options_.at(a).idx < options_.at(b).idx;
}); });
return keys; return keys;

View File

@ -44,7 +44,7 @@ struct CLIAliasTuple {
class CLIFormatter : public CLI::Formatter { class CLIFormatter : public CLI::Formatter {
public: public:
CLIFormatter(size_t columnWidth, size_t screenWidth); CLIFormatter(size_t columnWidth, size_t screenWidth);
virtual std::string make_option_desc(const CLI::Option *) const override; virtual std::string make_option_desc(const CLI::Option*) const override;
private: private:
size_t screenWidth_{0}; size_t screenWidth_{0};
@ -85,12 +85,7 @@ private:
// Extract option name from a comma-separated list of long and short options, e.g. 'help' from // Extract option name from a comma-separated list of long and short options, e.g. 'help' from
// '--help,-h' // '--help,-h'
std::string keyName(const std::string &args) const { std::string keyName(const std::string &args) const;
// re-use existing functions from CLI11 to keep option names consistent
return std::get<1>(
CLI::detail::get_names(CLI::detail::split_names(args))) // get long names only
.front(); // get first long name
}
// Get names of options passed via command-line // Get names of options passed via command-line
std::unordered_set<std::string> getParsedOptionNames() const; std::unordered_set<std::string> getParsedOptionNames() const;
@ -134,7 +129,7 @@ public:
* @return Option object * @return Option object
*/ */
template <typename T> template <typename T>
CLI::Option *add(const std::string &args, const std::string &help, T val) { CLI::Option* add(const std::string& args, const std::string& help, T val) {
return addOption<T>(keyName(args), return addOption<T>(keyName(args),
args, args,
help, help,
@ -159,7 +154,7 @@ public:
* @TODO: require to always state the default value creating the parser as this will be clearer * @TODO: require to always state the default value creating the parser as this will be clearer
*/ */
template <typename T> template <typename T>
CLI::Option *add(const std::string &args, const std::string &help) { CLI::Option* add(const std::string& args, const std::string& help) {
return addOption<T>(keyName(args), return addOption<T>(keyName(args),
args, args,
help, help,
@ -206,7 +201,7 @@ public:
std::string switchGroup(std::string name = ""); std::string switchGroup(std::string name = "");
// Parse command-line arguments. Handles --help and --version options // Parse command-line arguments. Handles --help and --version options
void parse(int argc, char **argv); void parse(int argc, char** argv);
/** /**
* @brief Expand aliases based on arguments parsed with parse(int, char**) * @brief Expand aliases based on arguments parsed with parse(int, char**)
@ -240,11 +235,12 @@ public:
std::string dumpConfig(bool skipUnmodified = false) const; std::string dumpConfig(bool skipUnmodified = false) const;
private: private:
template <typename T, template <typename T>
// options with numeric and string-like values using EnableIfNumbericOrString = CLI::enable_if_t<!CLI::is_bool<T>::value
CLI::enable_if_t<!CLI::is_bool<T>::value && !CLI::is_vector<T>::value, && !CLI::is_vector<T>::value, CLI::detail::enabler>;
CLI::detail::enabler> = CLI::detail::dummy>
CLI::Option *addOption(const std::string &key, template <typename T, EnableIfNumbericOrString<T> = CLI::detail::dummy>
CLI::Option* addOption(const std::string &key,
const std::string &args, const std::string &args,
const std::string &help, const std::string &help,
T val, T val,
@ -261,7 +257,7 @@ private:
CLI::callback_t fun = [this, key](CLI::results_t res) { CLI::callback_t fun = [this, key](CLI::results_t res) {
options_[key].priority = cli::OptionPriority::CommandLine; options_[key].priority = cli::OptionPriority::CommandLine;
// get variable associated with the option // get variable associated with the option
auto &var = options_[key].var->as<T>(); auto& var = options_[key].var->as<T>();
// store parser result in var // store parser result in var
auto ret = CLI::detail::lexical_cast(res[0], var); auto ret = CLI::detail::lexical_cast(res[0], var);
// update YAML entry // update YAML entry
@ -288,10 +284,11 @@ private:
return options_[key].opt; return options_[key].opt;
} }
template <typename T, template <typename T>
// options with vector values using EnableIfVector = CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler>;
CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
CLI::Option *addOption(const std::string &key, template <typename T, EnableIfVector<T> = CLI::detail::dummy>
CLI::Option* addOption(const std::string &key,
const std::string &args, const std::string &args,
const std::string &help, const std::string &help,
T val, T val,
@ -308,7 +305,7 @@ private:
CLI::callback_t fun = [this, key](CLI::results_t res) { CLI::callback_t fun = [this, key](CLI::results_t res) {
options_[key].priority = cli::OptionPriority::CommandLine; options_[key].priority = cli::OptionPriority::CommandLine;
// get vector variable associated with the option // get vector variable associated with the option
auto &vec = options_[key].var->as<T>(); auto& vec = options_[key].var->as<T>();
vec.clear(); vec.clear();
bool ret = true; bool ret = true;
// handle '[]' as an empty vector // handle '[]' as an empty vector
@ -316,7 +313,7 @@ private:
ret = true; ret = true;
} else { } else {
// populate the vector with parser results // populate the vector with parser results
for(const auto &a : res) { for(const auto& a : res) {
vec.emplace_back(); vec.emplace_back();
ret &= CLI::detail::lexical_cast(a, vec.back()); ret &= CLI::detail::lexical_cast(a, vec.back());
} }
@ -345,10 +342,11 @@ private:
return options_[key].opt; return options_[key].opt;
} }
template <typename T, template <typename T>
// options with boolean values, called flags in CLI11 using EnableIfBoolean = CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler>;
CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
CLI::Option *addOption(const std::string &key, template <typename T, EnableIfBoolean<T> = CLI::detail::dummy>
CLI::Option* addOption(const std::string &key,
const std::string &args, const std::string &args,
const std::string &help, const std::string &help,
T val, T val,

View File

@ -107,7 +107,7 @@ private:
* @param mode change the set of available command-line options, e.g. training, translation, etc. * @param mode change the set of available command-line options, e.g. training, translation, etc.
* @param validate validate parsed options and abort on failure * @param validate validate parsed options and abort on failure
* *
* @return parsed otions * @return parsed options
*/ */
Ptr<Options> parseOptions(int argc, Ptr<Options> parseOptions(int argc,
char** argv, char** argv,

View File

@ -119,10 +119,10 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
cli.add<std::vector<std::string>>("--config,-c", cli.add<std::vector<std::string>>("--config,-c",
"Configuration file(s). If multiple, later overrides earlier"); "Configuration file(s). If multiple, later overrides earlier");
cli.add<size_t>("--workspace,-w", cli.add<size_t>("--workspace,-w",
"Preallocate arg MB of work space", "Preallocate arg MB of work space",
defaultWorkspace); defaultWorkspace);
cli.add<std::string>("--log", cli.add<std::string>("--log",
"Log training process information to file given by arg"); "Log training process information to file given by arg");
cli.add<std::string>("--log-level", cli.add<std::string>("--log-level",
"Set verbosity level of logging: trace, debug, info, warn, err(or), critical, off", "Set verbosity level of logging: trace, debug, info, warn, err(or), critical, off",
"info"); "info");
@ -392,17 +392,17 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"Finish after this many chosen training units, 0 is infinity (e.g. 100e = 100 epochs, 10Gt = 10 billion target labels, 100Ku = 100,000 updates", "Finish after this many chosen training units, 0 is infinity (e.g. 100e = 100 epochs, 10Gt = 10 billion target labels, 100Ku = 100,000 updates",
"0e"); "0e");
cli.add<std::string/*SchedulerPeriod*/>("--disp-freq", cli.add<std::string/*SchedulerPeriod*/>("--disp-freq",
"Display information every arg updates (append 't' for every arg target labels)", "Display information every arg updates (append 't' for every arg target labels)",
"1000u"); "1000u");
cli.add<size_t>("--disp-first", cli.add<size_t>("--disp-first",
"Display information for the first arg updates"); "Display information for the first arg updates");
cli.add<bool>("--disp-label-counts", cli.add<bool>("--disp-label-counts",
"Display label counts when logging loss progress", "Display label counts when logging loss progress",
true); true);
// cli.add<int>("--disp-label-index", // cli.add<int>("--disp-label-index",
// "Display label counts based on i-th input stream (-1 is last)", -1); // "Display label counts based on i-th input stream (-1 is last)", -1);
cli.add<std::string/*SchedulerPeriod*/>("--save-freq", cli.add<std::string/*SchedulerPeriod*/>("--save-freq",
"Save model file every arg updates (append 't' for every arg target labels)", "Save model file every arg updates (append 't' for every arg target labels)",
"10000u"); "10000u");
cli.add<std::vector<std::string>>("--logical-epoch", cli.add<std::vector<std::string>>("--logical-epoch",
"Redefine logical epoch counter as multiple of data epochs (e.g. 1e), updates (e.g. 100Ku) or labels (e.g. 1Gt). " "Redefine logical epoch counter as multiple of data epochs (e.g. 1e), updates (e.g. 100Ku) or labels (e.g. 1Gt). "
@ -473,12 +473,12 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
cli.add<bool>("--lr-decay-repeat-warmup", cli.add<bool>("--lr-decay-repeat-warmup",
"Repeat learning rate warmup when learning rate is decayed"); "Repeat learning rate warmup when learning rate is decayed");
cli.add<std::vector<std::string/*SchedulerPeriod*/>>("--lr-decay-inv-sqrt", cli.add<std::vector<std::string/*SchedulerPeriod*/>>("--lr-decay-inv-sqrt",
"Decrease learning rate at arg / sqrt(no. batches) starting at arg (append 't' or 'e' for sqrt(target labels or epochs)). " "Decrease learning rate at arg / sqrt(no. batches) starting at arg (append 't' or 'e' for sqrt(target labels or epochs)). "
"Add second argument to define the starting point (default: same as first value)", "Add second argument to define the starting point (default: same as first value)",
{"0"}); {"0"});
cli.add<std::string/*SchedulerPeriod*/>("--lr-warmup", cli.add<std::string/*SchedulerPeriod*/>("--lr-warmup",
"Increase learning rate linearly for arg first batches (append 't' for arg first target labels)", "Increase learning rate linearly for arg first batches (append 't' for arg first target labels)",
"0"); "0");
cli.add<float>("--lr-warmup-start-rate", cli.add<float>("--lr-warmup-start-rate",
"Start value for learning rate warmup"); "Start value for learning rate warmup");
@ -492,7 +492,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
cli.add<double>("--factor-weight", cli.add<double>("--factor-weight",
"Weight for loss function for factors (factored vocab only) (1 to disable)", 1.0f); "Weight for loss function for factors (factored vocab only) (1 to disable)", 1.0f);
cli.add<float>("--clip-norm", cli.add<float>("--clip-norm",
"Clip gradient norm to arg (0 to disable)", "Clip gradient norm to arg (0 to disable)",
1.f); // @TODO: this is currently wrong with ce-sum and should rather be disabled or fixed by multiplying with labels 1.f); // @TODO: this is currently wrong with ce-sum and should rather be disabled or fixed by multiplying with labels
cli.add<float>("--exponential-smoothing", cli.add<float>("--exponential-smoothing",
"Maintain smoothed version of parameters for validation and saving with smoothing factor. 0 to disable. " "Maintain smoothed version of parameters for validation and saving with smoothing factor. 0 to disable. "
@ -575,7 +575,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
cli.add<std::vector<std::string>>("--valid-sets", cli.add<std::vector<std::string>>("--valid-sets",
"Paths to validation corpora: source target"); "Paths to validation corpora: source target");
cli.add<std::string/*SchedulerPeriod*/>("--valid-freq", cli.add<std::string/*SchedulerPeriod*/>("--valid-freq",
"Validate model every arg updates (append 't' for every arg target labels)", "Validate model every arg updates (append 't' for every arg target labels)",
"10000u"); "10000u");
cli.add<std::vector<std::string>>("--valid-metrics", cli.add<std::vector<std::string>>("--valid-metrics",
"Metric to use during validation: cross-entropy, ce-mean-words, perplexity, valid-script, " "Metric to use during validation: cross-entropy, ce-mean-words, perplexity, valid-script, "
@ -585,7 +585,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
cli.add<bool>("--valid-reset-stalled", cli.add<bool>("--valid-reset-stalled",
"Reset all stalled validation metrics when the training is restarted"); "Reset all stalled validation metrics when the training is restarted");
cli.add<size_t>("--early-stopping", cli.add<size_t>("--early-stopping",
"Stop if the first validation metric does not improve for arg consecutive validation steps", "Stop if the first validation metric does not improve for arg consecutive validation steps",
10); 10);
cli.add<std::string>("--early-stopping-on", cli.add<std::string>("--early-stopping-on",
"Decide if early stopping should take into account first, all, or any validation metrics" "Decide if early stopping should take into account first, all, or any validation metrics"
@ -637,7 +637,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
cli.add<bool>("--keep-best", cli.add<bool>("--keep-best",
"Keep best model for each validation metric"); "Keep best model for each validation metric");
cli.add<std::string>("--valid-log", cli.add<std::string>("--valid-log",
"Log validation scores to file given by arg"); "Log validation scores to file given by arg");
cli.switchGroup(previous_group); cli.switchGroup(previous_group);
// clang-format on // clang-format on
} }
@ -942,10 +942,10 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
cli.add<std::string>("--ulr-query-vectors", cli.add<std::string>("--ulr-query-vectors",
"Path to file with universal sources embeddings from projection into universal space", "Path to file with universal sources embeddings from projection into universal space",
""); "");
// keys: EK in Fig2 : is the keys of the target embbedings projected to unified space (i.e. ENU in // keys: EK in Fig2 : is the keys of the target embeddings projected to unified space (i.e. ENU in
// multi-lingual case) // multi-lingual case)
cli.add<std::string>("--ulr-keys-vectors", cli.add<std::string>("--ulr-keys-vectors",
"Path to file with universal sources embeddings of traget keys from projection into universal space", "Path to file with universal sources embeddings of target keys from projection into universal space",
""); "");
cli.add<bool>("--ulr-trainable-transformation", cli.add<bool>("--ulr-trainable-transformation",
"Make Query Transformation Matrix A trainable"); "Make Query Transformation Matrix A trainable");

View File

@ -10,7 +10,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is neccessary, remove if no problems occur. #define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is necessary, remove if no problems occur.
#define NodeOp(op) [=]() { op; } #define NodeOp(op) [=]() { op; }
// helper macro to disable optimization (gcc only) // helper macro to disable optimization (gcc only)

View File

@ -136,11 +136,11 @@ static void setErrorHandlers() {
// modify the log pattern for the "general" logger to include the MPI rank // modify the log pattern for the "general" logger to include the MPI rank
// This is called upon initializing MPI. It is needed to associated error messages to ranks. // This is called upon initializing MPI. It is needed to associated error messages to ranks.
void switchtoMultinodeLogging(std::string nodeIdStr) { void switchToMultinodeLogging(std::string nodeIdStr) {
Logger log = spdlog::get("general"); Logger log = spdlog::get("general");
if(log) if(log)
log->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] %v", nodeIdStr)); log->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] %v", nodeIdStr));
Logger valid = spdlog::get("valid"); Logger valid = spdlog::get("valid");
if(valid) if(valid)
valid->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] [valid] %v", nodeIdStr)); valid->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] [valid] %v", nodeIdStr));

View File

@ -12,19 +12,19 @@ namespace marian {
std::string getCallStack(size_t skipLevels); std::string getCallStack(size_t skipLevels);
// Marian gives a basic exception guarantee. If you catch a // Marian gives a basic exception guarantee. If you catch a
// MarianRuntimeError you must assume that the object can be // MarianRuntimeError you must assume that the object can be
// safely destructed, but cannot be used otherwise. // safely destructed, but cannot be used otherwise.
// Internal multi-threading in exception-throwing mode is not // Internal multi-threading in exception-throwing mode is not
// allowed; and constructing a thread-pool will cause an exception. // allowed; and constructing a thread-pool will cause an exception.
class MarianRuntimeException : public std::runtime_error { class MarianRuntimeException : public std::runtime_error {
private: private:
std::string callStack_; std::string callStack_;
public: public:
MarianRuntimeException(const std::string& message, const std::string& callStack) MarianRuntimeException(const std::string& message, const std::string& callStack)
: std::runtime_error(message), : std::runtime_error(message),
callStack_(callStack) {} callStack_(callStack) {}
const char* getCallStack() const throw() { const char* getCallStack() const throw() {
@ -178,4 +178,4 @@ void checkedLog(std::string logger, std::string level, Args... args) {
} }
void createLoggers(const marian::Config* options = nullptr); void createLoggers(const marian::Config* options = nullptr);
void switchtoMultinodeLogging(std::string nodeIdStr); void switchToMultinodeLogging(std::string nodeIdStr);

View File

@ -98,7 +98,7 @@ public:
* @brief Splice options from a YAML node * @brief Splice options from a YAML node
* *
* By default, only options with keys that do not already exist in options_ are extracted from * By default, only options with keys that do not already exist in options_ are extracted from
* node. These options are cloned if overwirte is true. * node. These options are cloned if overwrite is true.
* *
* @param node a YAML node to transfer the options from * @param node a YAML node to transfer the options from
* @param overwrite overwrite all options * @param overwrite overwrite all options

View File

@ -379,7 +379,7 @@ public:
* @see marian::data::SubBatch::split(size_t n) * @see marian::data::SubBatch::split(size_t n)
*/ */
std::vector<Ptr<Batch>> split(size_t n, size_t sizeLimit /*=SIZE_MAX*/) override { std::vector<Ptr<Batch>> split(size_t n, size_t sizeLimit /*=SIZE_MAX*/) override {
ABORT_IF(size() == 0, "Encoutered batch size of 0"); ABORT_IF(size() == 0, "Encountered batch size of 0");
std::vector<std::vector<Ptr<SubBatch>>> subs; // [subBatchIndex][streamIndex] std::vector<std::vector<Ptr<SubBatch>>> subs; // [subBatchIndex][streamIndex]
// split each stream separately // split each stream separately
@ -523,8 +523,8 @@ class CorpusBase : public DatasetBase<SentenceTuple, CorpusIterator, CorpusBatch
public: public:
typedef SentenceTuple Sample; typedef SentenceTuple Sample;
CorpusBase(Ptr<Options> options, CorpusBase(Ptr<Options> options,
bool translate = false, bool translate = false,
size_t seed = Config::seed); size_t seed = Config::seed);
CorpusBase(const std::vector<std::string>& paths, CorpusBase(const std::vector<std::string>& paths,

View File

@ -44,7 +44,7 @@ public:
virtual void prepare() {} virtual void prepare() {}
virtual void restore(Ptr<TrainingState>) {} virtual void restore(Ptr<TrainingState>) {}
// @TODO: remove after cleaning traininig/training.h // @TODO: remove after cleaning training/training.h
virtual Ptr<Options> options() { return options_; } virtual Ptr<Options> options() { return options_; }
}; };

View File

@ -10,7 +10,7 @@ namespace marian {
class IVocab; class IVocab;
// Wrapper around vocabulary types. Can choose underlying // Wrapper around vocabulary types. Can choose underlying
// vocabulary implementation (vImpl_) based on speficied path // vocabulary implementation (vImpl_) based on specified path
// and suffix. // and suffix.
// Vocabulary implementations can currently be: // Vocabulary implementations can currently be:
// * DefaultVocabulary for YAML (*.yml and *.yaml) and TXT (any other non-specific ending) // * DefaultVocabulary for YAML (*.yml and *.yaml) and TXT (any other non-specific ending)

View File

@ -76,7 +76,7 @@ public:
Ptr<Allocator> getAllocator() { return tensors_->allocator(); } Ptr<Allocator> getAllocator() { return tensors_->allocator(); }
Ptr<TensorAllocator> getTensorAllocator() { return tensors_; } Ptr<TensorAllocator> getTensorAllocator() { return tensors_; }
Expr findOrRemember(Expr node) { Expr findOrRemember(Expr node) {
size_t hash = node->hash(); size_t hash = node->hash();
// memoize constant nodes that are not parameters // memoize constant nodes that are not parameters
@ -359,9 +359,9 @@ private:
// Find the named parameter and its typed parent parameter object (params) and return both. // Find the named parameter and its typed parent parameter object (params) and return both.
// If the parameter is not found return the parent parameter object that the parameter should be added to. // If the parameter is not found return the parent parameter object that the parameter should be added to.
// Return [nullptr, nullptr] if no matching parent parameter object exists. // Return [nullptr, nullptr] if no matching parent parameter object exists.
std::tuple<Expr, Ptr<Parameters>> findParams(const std::string& name, std::tuple<Expr, Ptr<Parameters>> findParams(const std::string& name,
Type elementType, Type elementType,
bool typeSpecified) const { bool typeSpecified) const {
Expr p; Ptr<Parameters> params; Expr p; Ptr<Parameters> params;
if(typeSpecified) { // type has been specified, so we are only allowed to look for a parameter with that type if(typeSpecified) { // type has been specified, so we are only allowed to look for a parameter with that type
@ -373,12 +373,12 @@ private:
} else { // type has not been specified, so we take any type as long as the name matches } else { // type has not been specified, so we take any type as long as the name matches
for(auto kvParams : paramsByElementType_) { for(auto kvParams : paramsByElementType_) {
p = kvParams.second->get(name); p = kvParams.second->get(name);
if(p) { // p has been found, return with matching params object if(p) { // p has been found, return with matching params object
params = kvParams.second; params = kvParams.second;
break; break;
} }
if(kvParams.first == elementType) // even if p has not been found, set the params object to be returned if(kvParams.first == elementType) // even if p has not been found, set the params object to be returned
params = kvParams.second; params = kvParams.second;
} }
@ -399,8 +399,8 @@ private:
Expr p; Ptr<Parameters> params; std::tie Expr p; Ptr<Parameters> params; std::tie
(p, params) = findParams(name, elementType, typeSpecified); (p, params) = findParams(name, elementType, typeSpecified);
if(!params) { if(!params) {
params = New<Parameters>(elementType); params = New<Parameters>(elementType);
params->init(backend_); params->init(backend_);
paramsByElementType_.insert({elementType, params}); paramsByElementType_.insert({elementType, params});
@ -632,13 +632,13 @@ public:
* Return the Parameters object related to the graph. * Return the Parameters object related to the graph.
* The Parameters object holds the whole set of the parameter nodes. * The Parameters object holds the whole set of the parameter nodes.
*/ */
Ptr<Parameters>& params() { Ptr<Parameters>& params() {
// There are no parameter objects, that's weird. // There are no parameter objects, that's weird.
ABORT_IF(paramsByElementType_.empty(), "No parameter object has been created"); ABORT_IF(paramsByElementType_.empty(), "No parameter object has been created");
// Safeguard against accessing parameters from the outside with multiple parameter types, not yet supported // Safeguard against accessing parameters from the outside with multiple parameter types, not yet supported
ABORT_IF(paramsByElementType_.size() > 1, "Calling of params() is currently not supported with multiple ({}) parameters", paramsByElementType_.size()); ABORT_IF(paramsByElementType_.size() > 1, "Calling of params() is currently not supported with multiple ({}) parameters", paramsByElementType_.size());
// Safeguard against accessing parameters from the outside with other than default parameter type, not yet supported // Safeguard against accessing parameters from the outside with other than default parameter type, not yet supported
auto it = paramsByElementType_.find(defaultElementType_); auto it = paramsByElementType_.find(defaultElementType_);
ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_); ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_);
@ -650,7 +650,7 @@ public:
* Return the Parameters object related to the graph by elementType. * Return the Parameters object related to the graph by elementType.
* The Parameters object holds the whole set of the parameter nodes of the given type. * The Parameters object holds the whole set of the parameter nodes of the given type.
*/ */
Ptr<Parameters>& params(Type elementType) { Ptr<Parameters>& params(Type elementType) {
auto it = paramsByElementType_.find(elementType); auto it = paramsByElementType_.find(elementType);
ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_); ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_);
return it->second; return it->second;
@ -661,8 +661,8 @@ public:
* The default value is used if some node type is not specified. * The default value is used if some node type is not specified.
*/ */
void setDefaultElementType(Type defaultElementType) { void setDefaultElementType(Type defaultElementType) {
ABORT_IF(!paramsByElementType_.empty() && defaultElementType != defaultElementType_, ABORT_IF(!paramsByElementType_.empty() && defaultElementType != defaultElementType_,
"Parameter objects already exist, cannot change default type from {} to {}", "Parameter objects already exist, cannot change default type from {} to {}",
defaultElementType_, defaultElementType); defaultElementType_, defaultElementType);
defaultElementType_ = defaultElementType; defaultElementType_ = defaultElementType;
} }
@ -746,7 +746,7 @@ public:
// skip over special parameters starting with "special:" // skip over special parameters starting with "special:"
if(pName.substr(0, 8) == "special:") if(pName.substr(0, 8) == "special:")
continue; continue;
// if during loading the loaded type is of the same type class as the default element type, allow conversion; // if during loading the loaded type is of the same type class as the default element type, allow conversion;
// otherwise keep the loaded type. This is used when e.g. loading a float32 model as a float16 model as both // otherwise keep the loaded type. This is used when e.g. loading a float32 model as a float16 model as both
// have type class TypeClass::float_type. // have type class TypeClass::float_type.
@ -781,9 +781,9 @@ public:
LOG(info, "Memory mapping model at {}", ptr); LOG(info, "Memory mapping model at {}", ptr);
auto items = io::mmapItems(ptr); auto items = io::mmapItems(ptr);
// Deal with default parameter set object that might not be a mapped object. // Deal with default parameter set object that might not be a mapped object.
// This gets assigned during ExpressionGraph::setDevice(...) and by default // This gets assigned during ExpressionGraph::setDevice(...) and by default
// would contain allocated tensors. Here we replace it with a mmapped version. // would contain allocated tensors. Here we replace it with a mmapped version.
auto it = paramsByElementType_.find(defaultElementType_); auto it = paramsByElementType_.find(defaultElementType_);
if(it != paramsByElementType_.end()) { if(it != paramsByElementType_.end()) {

View File

@ -27,12 +27,12 @@ Expr checkpoint(Expr a) {
return a; return a;
} }
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
LambdaNodeFunctor fwd, size_t hash) { LambdaNodeFunctor fwd, size_t hash) {
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash); return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
} }
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash) { LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash) {
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd, hash); return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd, hash);
} }
@ -436,7 +436,7 @@ Expr std(Expr a, int ax) {
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::rms); return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::rms);
} }
Expr var(Expr a, int ax) { Expr var(Expr a, int ax) {
if(a->shape()[ax] == 1) // nothing to reduce, var(a) = 0 if(a->shape()[ax] == 1) // nothing to reduce, var(a) = 0
return a - a; return a - a;
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::meanSqr); return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::meanSqr);
@ -575,8 +575,8 @@ Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float sc
return Expression<AffineNodeOp>(nodes, transA, transB, scale); return Expression<AffineNodeOp>(nodes, transA, transB, scale);
} }
// This operation used to implement auto-tuning. We have removed it for now due to complexity, but plan to revisit it in the future. // This operation used to implement auto-tuning. We have removed it for now due to complexity, but plan to revisit it in the future.
// The last branch with auto-tuner is: // The last branch with auto-tuner is:
// youki/packed-model-pr-backup1031 // youki/packed-model-pr-backup1031
// https://machinetranslation.visualstudio.com/Marian/_git/marian-dev?version=GByouki%2Fpacked-model-pr-backup1031 // https://machinetranslation.visualstudio.com/Marian/_git/marian-dev?version=GByouki%2Fpacked-model-pr-backup1031
// SHA: 3456a7ed1d1608cfad74cd2c414e7e8fe141aa52 // SHA: 3456a7ed1d1608cfad74cd2c414e7e8fe141aa52
@ -660,8 +660,8 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
} }
} else { } else {
// Default GEMM // Default GEMM
ABORT_IF(!isFloat(aElementType) || !isFloat(bElementType), ABORT_IF(!isFloat(aElementType) || !isFloat(bElementType),
"GPU-based GEMM only supports float types, you have A: {} and B: {}", "GPU-based GEMM only supports float types, you have A: {} and B: {}",
aElementType, bElementType); aElementType, bElementType);
return affineDefault(a, b, bias, transA, transB, scale); return affineDefault(a, b, bias, transA, transB, scale);
} }
@ -669,7 +669,7 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) { Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
auto graph = a->graph(); auto graph = a->graph();
if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu) if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu)
return Expression<AffineWithReluNodeOp>(a, b, bias, transA, transB, scale); return Expression<AffineWithReluNodeOp>(a, b, bias, transA, transB, scale);
else else
@ -775,7 +775,7 @@ Expr unlikelihood(Expr logits, Expr indices) {
int dimBatch = logits->shape()[-2]; int dimBatch = logits->shape()[-2];
int dimTime = logits->shape()[-3]; int dimTime = logits->shape()[-3];
// @TODO: fix this outside of this function in decoder.h etc. // @TODO: fix this outside of this function in decoder.h etc.
auto indicesWithLayout = reshape(indices, {1, dimTime, dimBatch, 1}); auto indicesWithLayout = reshape(indices, {1, dimTime, dimBatch, 1});
// This is currently implemented with multiple ops, might be worth doing a special operation like for cross_entropy // This is currently implemented with multiple ops, might be worth doing a special operation like for cross_entropy

View File

@ -70,9 +70,9 @@ public:
outputs.push_back(output2); outputs.push_back(output2);
} }
auto concated = concatenate(outputs, -1); auto concatenated = concatenate(outputs, -1);
return concated; return concatenated;
} }
protected: protected:

View File

@ -67,7 +67,7 @@ public:
return count_->val()->scalar<T>(); return count_->val()->scalar<T>();
} }
// @TODO: add a funtion for returning maybe ratio? // @TODO: add a function for returning maybe ratio?
size_t size() const { size_t size() const {
ABORT_IF(!count_, "Labels have not been defined"); ABORT_IF(!count_, "Labels have not been defined");
@ -189,7 +189,7 @@ public:
* *
* L = sum_i^N L_i + N/M sum_j^M L_j * L = sum_i^N L_i + N/M sum_j^M L_j
* *
* We set labels to N. When reporting L/N this is equvalient to sum of means. * We set labels to N. When reporting L/N this is equivalent to sum of means.
* Compare to sum of means below where N is factored into the loss, but labels * Compare to sum of means below where N is factored into the loss, but labels
* are set to 1. * are set to 1.
*/ */

View File

@ -76,7 +76,7 @@ private:
float scale = sqrtf(2.0f / (dimVoc + dimEmb)); float scale = sqrtf(2.0f / (dimVoc + dimEmb));
// @TODO: switch to new random generator back-end. // @TODO: switch to new random generator back-end.
// This is rarly used however. // This is rarely used however.
std::random_device rd; std::random_device rd;
std::mt19937 engine(rd()); std::mt19937 engine(rd());

View File

@ -8,7 +8,7 @@
namespace marian { namespace marian {
/** /**
* This file contains nearly all BERT-related code and adds BERT-funtionality * This file contains nearly all BERT-related code and adds BERT-functionality
* on top of existing classes like TansformerEncoder and Classifier. * on top of existing classes like TansformerEncoder and Classifier.
*/ */
@ -82,7 +82,7 @@ public:
// Initialize to sample random vocab id // Initialize to sample random vocab id
randomWord_.reset(new std::uniform_int_distribution<WordIndex>(0, (WordIndex)vocab.size())); randomWord_.reset(new std::uniform_int_distribution<WordIndex>(0, (WordIndex)vocab.size()));
// Intialize to sample random percentage // Initialize to sample random percentage
randomPercent_.reset(new std::uniform_real_distribution<float>(0.f, 1.f)); randomPercent_.reset(new std::uniform_real_distribution<float>(0.f, 1.f));
auto& words = subBatch->data(); auto& words = subBatch->data();

View File

@ -14,7 +14,7 @@ namespace marian {
* Can be used to train sequence classifiers like language detection, BERT-next-sentence-prediction etc. * Can be used to train sequence classifiers like language detection, BERT-next-sentence-prediction etc.
* Already has support for multi-objective training. * Already has support for multi-objective training.
* *
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for deocder/classifier * @TODO: this should probably be unified somehow with EncoderDecoder which could allow for decoder/classifier
* multi-objective training. * multi-objective training.
*/ */
class EncoderClassifierBase : public models::IModel { class EncoderClassifierBase : public models::IModel {

View File

@ -220,7 +220,7 @@ Ptr<DecoderState> EncoderDecoder::stepAll(Ptr<ExpressionGraph> graph,
if(clearGraph) if(clearGraph)
clear(graph); clear(graph);
// Required first step, also intializes shortlist // Required first step, also initializes shortlist
auto state = startState(graph, batch); auto state = startState(graph, batch);
// Fill state with embeddings from batch (ground truth) // Fill state with embeddings from batch (ground truth)

View File

@ -70,7 +70,7 @@ public:
// Hack for translating with length longer than trained embeddings // Hack for translating with length longer than trained embeddings
// We check if the embedding matrix "Wpos" already exist so we can // We check if the embedding matrix "Wpos" already exist so we can
// check the number of positions in that loaded parameter. // check the number of positions in that loaded parameter.
// We then have to restict the maximum length to the maximum positon // We then have to restrict the maximum length to the maximum positon
// and positions beyond this will be the maximum position. // and positions beyond this will be the maximum position.
Expr seenEmb = graph_->get("Wpos"); Expr seenEmb = graph_->get("Wpos");
int numPos = seenEmb ? seenEmb->shape()[-2] : maxLength; int numPos = seenEmb ? seenEmb->shape()[-2] : maxLength;

View File

@ -101,7 +101,7 @@ public:
std::string maxRankStr = std::to_string(MPIWrapper::numMPIProcesses() -1); std::string maxRankStr = std::to_string(MPIWrapper::numMPIProcesses() -1);
while (rankStr.size() < maxRankStr.size()) // pad so that logs across MPI processes line up nicely while (rankStr.size() < maxRankStr.size()) // pad so that logs across MPI processes line up nicely
rankStr.insert(rankStr.begin(), ' '); rankStr.insert(rankStr.begin(), ' ');
switchtoMultinodeLogging(rankStr); switchToMultinodeLogging(rankStr);
} }
// log hostnames in order, and test // log hostnames in order, and test
@ -261,7 +261,7 @@ void finalizeMPI(Ptr<IMPIWrapper>&& mpi) {
ABORT_IF(mpi == nullptr || mpi != s_mpi, "attempted to finalize an inconsistent MPI instance. This should not be possible."); ABORT_IF(mpi == nullptr || mpi != s_mpi, "attempted to finalize an inconsistent MPI instance. This should not be possible.");
mpi = nullptr; // destruct caller's handle mpi = nullptr; // destruct caller's handle
ABORT_IF(s_mpiUseCount == 0, "finalize called too many times. This should not be possible."); ABORT_IF(s_mpiUseCount == 0, "finalize called too many times. This should not be possible.");
if (s_mpiUseCount == 1) { // last call finalizes MPI, i.e. tells MPI that we sucessfully completed computation if (s_mpiUseCount == 1) { // last call finalizes MPI, i.e. tells MPI that we successfully completed computation
ABORT_IF(s_mpi.use_count() != 1, "dangling reference to MPI??"); // caller kept another shared_ptr to this instance ABORT_IF(s_mpi.use_count() != 1, "dangling reference to MPI??"); // caller kept another shared_ptr to this instance
s_mpi->finalize(); // signal successful completion to MPI s_mpi->finalize(); // signal successful completion to MPI
s_mpi = nullptr; // release the singleton instance upon last finalization s_mpi = nullptr; // release the singleton instance upon last finalization

View File

@ -13,7 +13,7 @@ namespace marian {
// With -Ofast enabled gcc will fail to identify NaN or Inf. Safeguard here. // With -Ofast enabled gcc will fail to identify NaN or Inf. Safeguard here.
static inline bool isFinite(float x) { static inline bool isFinite(float x) {
#ifdef __GNUC__ #ifdef __GNUC__
ABORT_IF(std::isfinite(0.f / 0.f), "NaN detection unreliable. Disable -Ofast compiler option."); ABORT_IF(std::isfinite(0.f / 0.f), "NaN detection unreliable. Disable -Ofast compiler option.");
#endif #endif
return std::isfinite(x); return std::isfinite(x);
@ -27,7 +27,7 @@ static inline bool isFinite(float x) {
// if one value is nonfinite propagate Nan into the reduction. // if one value is nonfinite propagate Nan into the reduction.
static inline void accNanOrNorm(float& lhs, float rhs) { static inline void accNanOrNorm(float& lhs, float rhs) {
if(isFinite(lhs) && isFinite(rhs)) { if(isFinite(lhs) && isFinite(rhs)) {
lhs = sqrtf(lhs * lhs + rhs * rhs); lhs = sqrtf(lhs * lhs + rhs * rhs);
} else } else
lhs = std::numeric_limits<float>::quiet_NaN(); lhs = std::numeric_limits<float>::quiet_NaN();
} }
@ -42,20 +42,20 @@ static inline void accNanOrNorm(float& lhs, float rhs) {
class GraphGroup { class GraphGroup {
protected: protected:
Ptr<Options> options_; Ptr<Options> options_;
Ptr<ICommunicator> comm_; // [not null] communicator, e.g. NCCLCommunicator Ptr<ICommunicator> comm_; // [not null] communicator, e.g. NCCLCommunicator
Ptr<IMPIWrapper> mpi_; // [not null] all MPI-like communication goes through this (this is a dummy implementation if no MPI run) Ptr<IMPIWrapper> mpi_; // [not null] all MPI-like communication goes through this (this is a dummy implementation if no MPI run)
std::vector<DeviceId> devices_; // [deviceIndex] std::vector<DeviceId> devices_; // [deviceIndex]
ShardingMode shardingMode_{ShardingMode::global}; // If local and multi-node training, shard only on local devices and do full sync (faster). If global shard across entire set of GPUs (more RAM). ShardingMode shardingMode_{ShardingMode::global}; // If local and multi-node training, shard only on local devices and do full sync (faster). If global shard across entire set of GPUs (more RAM).
// common for all graph groups, individual graph groups decide how to fill them // common for all graph groups, individual graph groups decide how to fill them
std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex] std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex]
std::vector<Ptr<models::ICriterionFunction>> models_; // [deviceIndex] std::vector<Ptr<models::ICriterionFunction>> models_; // [deviceIndex]
std::vector<Ptr<OptimizerBase>> optimizerShards_; // [deviceIndex] std::vector<Ptr<OptimizerBase>> optimizerShards_; // [deviceIndex]
Ptr<Scheduler> scheduler_; // scheduler that keeps track of how much has been processed Ptr<Scheduler> scheduler_; // scheduler that keeps track of how much has been processed
bool finalized_{false}; // 'true' if training has completed (further updates are no longer allowed) bool finalized_{false}; // 'true' if training has completed (further updates are no longer allowed)
double typicalTrgBatchWords_{0}; // for dynamic batch sizing: typical batch size in words double typicalTrgBatchWords_{0}; // for dynamic batch sizing: typical batch size in words
bool mbRoundUp_{true}; // round up batches for more efficient training but can make batch size less stable, disable with --mini-batch-round-up=false bool mbRoundUp_{true}; // round up batches for more efficient training but can make batch size less stable, disable with --mini-batch-round-up=false
@ -100,16 +100,16 @@ public:
virtual void load(); virtual void load();
virtual void save(bool isFinal = false); virtual void save(bool isFinal = false);
private: private:
void load(const OptimizerBase::ScatterStateFunc& scatterFn); void load(const OptimizerBase::ScatterStateFunc& scatterFn);
void save(bool isFinal, void save(bool isFinal,
const OptimizerBase::GatherStateFunc& gatherOptimizerStateFn); const OptimizerBase::GatherStateFunc& gatherOptimizerStateFn);
bool restoreFromCheckpoint(const std::string& modelFileName, bool restoreFromCheckpoint(const std::string& modelFileName,
const OptimizerBase::ScatterStateFunc& scatterFn); const OptimizerBase::ScatterStateFunc& scatterFn);
void saveCheckpoint(const std::string& modelFileName, void saveCheckpoint(const std::string& modelFileName,
const OptimizerBase::GatherStateFunc& gatherFn); const OptimizerBase::GatherStateFunc& gatherFn);
public: public:
@ -128,11 +128,11 @@ public:
float executeAndCollectNorm(const std::function<float(size_t, size_t, size_t)>& task); float executeAndCollectNorm(const std::function<float(size_t, size_t, size_t)>& task);
float computeNormalizationFactor(float gNorm, size_t updateTrgWords); float computeNormalizationFactor(float gNorm, size_t updateTrgWords);
/** /**
* Determine maximal batch size that can fit into the given workspace * Determine maximal batch size that can fit into the given workspace
* so that reallocation does not happen. Rather adjust the batch size * so that reallocation does not happen. Rather adjust the batch size
* based on the stastistics collected here. Activated with * based on the statistics collected here. Activated with
* `--mini-batch-fit`. * `--mini-batch-fit`.
* In a multi-GPU scenario, the first GPU is used to determine the size. * In a multi-GPU scenario, the first GPU is used to determine the size.
* The actual allowed size is then determined by multiplying it with the * The actual allowed size is then determined by multiplying it with the
@ -151,4 +151,4 @@ public:
void updateAverageTrgBatchWords(size_t trgBatchWords); void updateAverageTrgBatchWords(size_t trgBatchWords);
}; };
} // namespace marian } // namespace marian