mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
Cherry picked cleaning/refeactoring patches (#905)
Cherry-picked updates from pull request #457 Co-authored-by: Mateusz Chudyk <mateuszchudyk@gmail.com>
This commit is contained in:
parent
71b5454b9e
commit
07c39c7d76
@ -113,10 +113,10 @@ std::string CLIWrapper::switchGroup(std::string name) {
|
|||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CLIWrapper::parse(int argc, char **argv) {
|
void CLIWrapper::parse(int argc, char** argv) {
|
||||||
try {
|
try {
|
||||||
app_->parse(argc, argv);
|
app_->parse(argc, argv);
|
||||||
} catch(const CLI::ParseError &e) {
|
} catch(const CLI::ParseError& e) {
|
||||||
exit(app_->exit(e));
|
exit(app_->exit(e));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -182,6 +182,13 @@ void CLIWrapper::parseAliases() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string CLIWrapper::keyName(const std::string& args) const {
|
||||||
|
// re-use existing functions from CLI11 to keep option names consistent
|
||||||
|
return std::get<1>(
|
||||||
|
CLI::detail::get_names(CLI::detail::split_names(args))) // get long names only
|
||||||
|
.front(); // get first long name
|
||||||
|
}
|
||||||
|
|
||||||
void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg) {
|
void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg) {
|
||||||
auto cmdOptions = getParsedOptionNames();
|
auto cmdOptions = getParsedOptionNames();
|
||||||
// Keep track of unrecognized options from the provided config
|
// Keep track of unrecognized options from the provided config
|
||||||
@ -276,7 +283,7 @@ std::vector<std::string> CLIWrapper::getOrderedOptionNames() const {
|
|||||||
for(auto const &it : options_)
|
for(auto const &it : options_)
|
||||||
keys.push_back(it.first);
|
keys.push_back(it.first);
|
||||||
// sort option names by creation index
|
// sort option names by creation index
|
||||||
sort(keys.begin(), keys.end(), [this](const std::string &a, const std::string &b) {
|
sort(keys.begin(), keys.end(), [this](const std::string& a, const std::string& b) {
|
||||||
return options_.at(a).idx < options_.at(b).idx;
|
return options_.at(a).idx < options_.at(b).idx;
|
||||||
});
|
});
|
||||||
return keys;
|
return keys;
|
||||||
|
@ -44,7 +44,7 @@ struct CLIAliasTuple {
|
|||||||
class CLIFormatter : public CLI::Formatter {
|
class CLIFormatter : public CLI::Formatter {
|
||||||
public:
|
public:
|
||||||
CLIFormatter(size_t columnWidth, size_t screenWidth);
|
CLIFormatter(size_t columnWidth, size_t screenWidth);
|
||||||
virtual std::string make_option_desc(const CLI::Option *) const override;
|
virtual std::string make_option_desc(const CLI::Option*) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
size_t screenWidth_{0};
|
size_t screenWidth_{0};
|
||||||
@ -85,12 +85,7 @@ private:
|
|||||||
|
|
||||||
// Extract option name from a comma-separated list of long and short options, e.g. 'help' from
|
// Extract option name from a comma-separated list of long and short options, e.g. 'help' from
|
||||||
// '--help,-h'
|
// '--help,-h'
|
||||||
std::string keyName(const std::string &args) const {
|
std::string keyName(const std::string &args) const;
|
||||||
// re-use existing functions from CLI11 to keep option names consistent
|
|
||||||
return std::get<1>(
|
|
||||||
CLI::detail::get_names(CLI::detail::split_names(args))) // get long names only
|
|
||||||
.front(); // get first long name
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get names of options passed via command-line
|
// Get names of options passed via command-line
|
||||||
std::unordered_set<std::string> getParsedOptionNames() const;
|
std::unordered_set<std::string> getParsedOptionNames() const;
|
||||||
@ -134,7 +129,7 @@ public:
|
|||||||
* @return Option object
|
* @return Option object
|
||||||
*/
|
*/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
CLI::Option *add(const std::string &args, const std::string &help, T val) {
|
CLI::Option* add(const std::string& args, const std::string& help, T val) {
|
||||||
return addOption<T>(keyName(args),
|
return addOption<T>(keyName(args),
|
||||||
args,
|
args,
|
||||||
help,
|
help,
|
||||||
@ -159,7 +154,7 @@ public:
|
|||||||
* @TODO: require to always state the default value creating the parser as this will be clearer
|
* @TODO: require to always state the default value creating the parser as this will be clearer
|
||||||
*/
|
*/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
CLI::Option *add(const std::string &args, const std::string &help) {
|
CLI::Option* add(const std::string& args, const std::string& help) {
|
||||||
return addOption<T>(keyName(args),
|
return addOption<T>(keyName(args),
|
||||||
args,
|
args,
|
||||||
help,
|
help,
|
||||||
@ -206,7 +201,7 @@ public:
|
|||||||
std::string switchGroup(std::string name = "");
|
std::string switchGroup(std::string name = "");
|
||||||
|
|
||||||
// Parse command-line arguments. Handles --help and --version options
|
// Parse command-line arguments. Handles --help and --version options
|
||||||
void parse(int argc, char **argv);
|
void parse(int argc, char** argv);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Expand aliases based on arguments parsed with parse(int, char**)
|
* @brief Expand aliases based on arguments parsed with parse(int, char**)
|
||||||
@ -240,11 +235,12 @@ public:
|
|||||||
std::string dumpConfig(bool skipUnmodified = false) const;
|
std::string dumpConfig(bool skipUnmodified = false) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <typename T,
|
template <typename T>
|
||||||
// options with numeric and string-like values
|
using EnableIfNumbericOrString = CLI::enable_if_t<!CLI::is_bool<T>::value
|
||||||
CLI::enable_if_t<!CLI::is_bool<T>::value && !CLI::is_vector<T>::value,
|
&& !CLI::is_vector<T>::value, CLI::detail::enabler>;
|
||||||
CLI::detail::enabler> = CLI::detail::dummy>
|
|
||||||
CLI::Option *addOption(const std::string &key,
|
template <typename T, EnableIfNumbericOrString<T> = CLI::detail::dummy>
|
||||||
|
CLI::Option* addOption(const std::string &key,
|
||||||
const std::string &args,
|
const std::string &args,
|
||||||
const std::string &help,
|
const std::string &help,
|
||||||
T val,
|
T val,
|
||||||
@ -261,7 +257,7 @@ private:
|
|||||||
CLI::callback_t fun = [this, key](CLI::results_t res) {
|
CLI::callback_t fun = [this, key](CLI::results_t res) {
|
||||||
options_[key].priority = cli::OptionPriority::CommandLine;
|
options_[key].priority = cli::OptionPriority::CommandLine;
|
||||||
// get variable associated with the option
|
// get variable associated with the option
|
||||||
auto &var = options_[key].var->as<T>();
|
auto& var = options_[key].var->as<T>();
|
||||||
// store parser result in var
|
// store parser result in var
|
||||||
auto ret = CLI::detail::lexical_cast(res[0], var);
|
auto ret = CLI::detail::lexical_cast(res[0], var);
|
||||||
// update YAML entry
|
// update YAML entry
|
||||||
@ -288,10 +284,11 @@ private:
|
|||||||
return options_[key].opt;
|
return options_[key].opt;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T,
|
template <typename T>
|
||||||
// options with vector values
|
using EnableIfVector = CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler>;
|
||||||
CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
|
|
||||||
CLI::Option *addOption(const std::string &key,
|
template <typename T, EnableIfVector<T> = CLI::detail::dummy>
|
||||||
|
CLI::Option* addOption(const std::string &key,
|
||||||
const std::string &args,
|
const std::string &args,
|
||||||
const std::string &help,
|
const std::string &help,
|
||||||
T val,
|
T val,
|
||||||
@ -308,7 +305,7 @@ private:
|
|||||||
CLI::callback_t fun = [this, key](CLI::results_t res) {
|
CLI::callback_t fun = [this, key](CLI::results_t res) {
|
||||||
options_[key].priority = cli::OptionPriority::CommandLine;
|
options_[key].priority = cli::OptionPriority::CommandLine;
|
||||||
// get vector variable associated with the option
|
// get vector variable associated with the option
|
||||||
auto &vec = options_[key].var->as<T>();
|
auto& vec = options_[key].var->as<T>();
|
||||||
vec.clear();
|
vec.clear();
|
||||||
bool ret = true;
|
bool ret = true;
|
||||||
// handle '[]' as an empty vector
|
// handle '[]' as an empty vector
|
||||||
@ -316,7 +313,7 @@ private:
|
|||||||
ret = true;
|
ret = true;
|
||||||
} else {
|
} else {
|
||||||
// populate the vector with parser results
|
// populate the vector with parser results
|
||||||
for(const auto &a : res) {
|
for(const auto& a : res) {
|
||||||
vec.emplace_back();
|
vec.emplace_back();
|
||||||
ret &= CLI::detail::lexical_cast(a, vec.back());
|
ret &= CLI::detail::lexical_cast(a, vec.back());
|
||||||
}
|
}
|
||||||
@ -345,10 +342,11 @@ private:
|
|||||||
return options_[key].opt;
|
return options_[key].opt;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T,
|
template <typename T>
|
||||||
// options with boolean values, called flags in CLI11
|
using EnableIfBoolean = CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler>;
|
||||||
CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
|
|
||||||
CLI::Option *addOption(const std::string &key,
|
template <typename T, EnableIfBoolean<T> = CLI::detail::dummy>
|
||||||
|
CLI::Option* addOption(const std::string &key,
|
||||||
const std::string &args,
|
const std::string &args,
|
||||||
const std::string &help,
|
const std::string &help,
|
||||||
T val,
|
T val,
|
||||||
|
@ -107,7 +107,7 @@ private:
|
|||||||
* @param mode change the set of available command-line options, e.g. training, translation, etc.
|
* @param mode change the set of available command-line options, e.g. training, translation, etc.
|
||||||
* @param validate validate parsed options and abort on failure
|
* @param validate validate parsed options and abort on failure
|
||||||
*
|
*
|
||||||
* @return parsed otions
|
* @return parsed options
|
||||||
*/
|
*/
|
||||||
Ptr<Options> parseOptions(int argc,
|
Ptr<Options> parseOptions(int argc,
|
||||||
char** argv,
|
char** argv,
|
||||||
|
@ -119,10 +119,10 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
|
|||||||
cli.add<std::vector<std::string>>("--config,-c",
|
cli.add<std::vector<std::string>>("--config,-c",
|
||||||
"Configuration file(s). If multiple, later overrides earlier");
|
"Configuration file(s). If multiple, later overrides earlier");
|
||||||
cli.add<size_t>("--workspace,-w",
|
cli.add<size_t>("--workspace,-w",
|
||||||
"Preallocate arg MB of work space",
|
"Preallocate arg MB of work space",
|
||||||
defaultWorkspace);
|
defaultWorkspace);
|
||||||
cli.add<std::string>("--log",
|
cli.add<std::string>("--log",
|
||||||
"Log training process information to file given by arg");
|
"Log training process information to file given by arg");
|
||||||
cli.add<std::string>("--log-level",
|
cli.add<std::string>("--log-level",
|
||||||
"Set verbosity level of logging: trace, debug, info, warn, err(or), critical, off",
|
"Set verbosity level of logging: trace, debug, info, warn, err(or), critical, off",
|
||||||
"info");
|
"info");
|
||||||
@ -392,17 +392,17 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
|||||||
"Finish after this many chosen training units, 0 is infinity (e.g. 100e = 100 epochs, 10Gt = 10 billion target labels, 100Ku = 100,000 updates",
|
"Finish after this many chosen training units, 0 is infinity (e.g. 100e = 100 epochs, 10Gt = 10 billion target labels, 100Ku = 100,000 updates",
|
||||||
"0e");
|
"0e");
|
||||||
cli.add<std::string/*SchedulerPeriod*/>("--disp-freq",
|
cli.add<std::string/*SchedulerPeriod*/>("--disp-freq",
|
||||||
"Display information every arg updates (append 't' for every arg target labels)",
|
"Display information every arg updates (append 't' for every arg target labels)",
|
||||||
"1000u");
|
"1000u");
|
||||||
cli.add<size_t>("--disp-first",
|
cli.add<size_t>("--disp-first",
|
||||||
"Display information for the first arg updates");
|
"Display information for the first arg updates");
|
||||||
cli.add<bool>("--disp-label-counts",
|
cli.add<bool>("--disp-label-counts",
|
||||||
"Display label counts when logging loss progress",
|
"Display label counts when logging loss progress",
|
||||||
true);
|
true);
|
||||||
// cli.add<int>("--disp-label-index",
|
// cli.add<int>("--disp-label-index",
|
||||||
// "Display label counts based on i-th input stream (-1 is last)", -1);
|
// "Display label counts based on i-th input stream (-1 is last)", -1);
|
||||||
cli.add<std::string/*SchedulerPeriod*/>("--save-freq",
|
cli.add<std::string/*SchedulerPeriod*/>("--save-freq",
|
||||||
"Save model file every arg updates (append 't' for every arg target labels)",
|
"Save model file every arg updates (append 't' for every arg target labels)",
|
||||||
"10000u");
|
"10000u");
|
||||||
cli.add<std::vector<std::string>>("--logical-epoch",
|
cli.add<std::vector<std::string>>("--logical-epoch",
|
||||||
"Redefine logical epoch counter as multiple of data epochs (e.g. 1e), updates (e.g. 100Ku) or labels (e.g. 1Gt). "
|
"Redefine logical epoch counter as multiple of data epochs (e.g. 1e), updates (e.g. 100Ku) or labels (e.g. 1Gt). "
|
||||||
@ -473,12 +473,12 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
|||||||
cli.add<bool>("--lr-decay-repeat-warmup",
|
cli.add<bool>("--lr-decay-repeat-warmup",
|
||||||
"Repeat learning rate warmup when learning rate is decayed");
|
"Repeat learning rate warmup when learning rate is decayed");
|
||||||
cli.add<std::vector<std::string/*SchedulerPeriod*/>>("--lr-decay-inv-sqrt",
|
cli.add<std::vector<std::string/*SchedulerPeriod*/>>("--lr-decay-inv-sqrt",
|
||||||
"Decrease learning rate at arg / sqrt(no. batches) starting at arg (append 't' or 'e' for sqrt(target labels or epochs)). "
|
"Decrease learning rate at arg / sqrt(no. batches) starting at arg (append 't' or 'e' for sqrt(target labels or epochs)). "
|
||||||
"Add second argument to define the starting point (default: same as first value)",
|
"Add second argument to define the starting point (default: same as first value)",
|
||||||
{"0"});
|
{"0"});
|
||||||
|
|
||||||
cli.add<std::string/*SchedulerPeriod*/>("--lr-warmup",
|
cli.add<std::string/*SchedulerPeriod*/>("--lr-warmup",
|
||||||
"Increase learning rate linearly for arg first batches (append 't' for arg first target labels)",
|
"Increase learning rate linearly for arg first batches (append 't' for arg first target labels)",
|
||||||
"0");
|
"0");
|
||||||
cli.add<float>("--lr-warmup-start-rate",
|
cli.add<float>("--lr-warmup-start-rate",
|
||||||
"Start value for learning rate warmup");
|
"Start value for learning rate warmup");
|
||||||
@ -492,7 +492,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
|||||||
cli.add<double>("--factor-weight",
|
cli.add<double>("--factor-weight",
|
||||||
"Weight for loss function for factors (factored vocab only) (1 to disable)", 1.0f);
|
"Weight for loss function for factors (factored vocab only) (1 to disable)", 1.0f);
|
||||||
cli.add<float>("--clip-norm",
|
cli.add<float>("--clip-norm",
|
||||||
"Clip gradient norm to arg (0 to disable)",
|
"Clip gradient norm to arg (0 to disable)",
|
||||||
1.f); // @TODO: this is currently wrong with ce-sum and should rather be disabled or fixed by multiplying with labels
|
1.f); // @TODO: this is currently wrong with ce-sum and should rather be disabled or fixed by multiplying with labels
|
||||||
cli.add<float>("--exponential-smoothing",
|
cli.add<float>("--exponential-smoothing",
|
||||||
"Maintain smoothed version of parameters for validation and saving with smoothing factor. 0 to disable. "
|
"Maintain smoothed version of parameters for validation and saving with smoothing factor. 0 to disable. "
|
||||||
@ -575,7 +575,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
|||||||
cli.add<std::vector<std::string>>("--valid-sets",
|
cli.add<std::vector<std::string>>("--valid-sets",
|
||||||
"Paths to validation corpora: source target");
|
"Paths to validation corpora: source target");
|
||||||
cli.add<std::string/*SchedulerPeriod*/>("--valid-freq",
|
cli.add<std::string/*SchedulerPeriod*/>("--valid-freq",
|
||||||
"Validate model every arg updates (append 't' for every arg target labels)",
|
"Validate model every arg updates (append 't' for every arg target labels)",
|
||||||
"10000u");
|
"10000u");
|
||||||
cli.add<std::vector<std::string>>("--valid-metrics",
|
cli.add<std::vector<std::string>>("--valid-metrics",
|
||||||
"Metric to use during validation: cross-entropy, ce-mean-words, perplexity, valid-script, "
|
"Metric to use during validation: cross-entropy, ce-mean-words, perplexity, valid-script, "
|
||||||
@ -585,7 +585,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
|||||||
cli.add<bool>("--valid-reset-stalled",
|
cli.add<bool>("--valid-reset-stalled",
|
||||||
"Reset all stalled validation metrics when the training is restarted");
|
"Reset all stalled validation metrics when the training is restarted");
|
||||||
cli.add<size_t>("--early-stopping",
|
cli.add<size_t>("--early-stopping",
|
||||||
"Stop if the first validation metric does not improve for arg consecutive validation steps",
|
"Stop if the first validation metric does not improve for arg consecutive validation steps",
|
||||||
10);
|
10);
|
||||||
cli.add<std::string>("--early-stopping-on",
|
cli.add<std::string>("--early-stopping-on",
|
||||||
"Decide if early stopping should take into account first, all, or any validation metrics"
|
"Decide if early stopping should take into account first, all, or any validation metrics"
|
||||||
@ -637,7 +637,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
|||||||
cli.add<bool>("--keep-best",
|
cli.add<bool>("--keep-best",
|
||||||
"Keep best model for each validation metric");
|
"Keep best model for each validation metric");
|
||||||
cli.add<std::string>("--valid-log",
|
cli.add<std::string>("--valid-log",
|
||||||
"Log validation scores to file given by arg");
|
"Log validation scores to file given by arg");
|
||||||
cli.switchGroup(previous_group);
|
cli.switchGroup(previous_group);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
@ -942,10 +942,10 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
|
|||||||
cli.add<std::string>("--ulr-query-vectors",
|
cli.add<std::string>("--ulr-query-vectors",
|
||||||
"Path to file with universal sources embeddings from projection into universal space",
|
"Path to file with universal sources embeddings from projection into universal space",
|
||||||
"");
|
"");
|
||||||
// keys: EK in Fig2 : is the keys of the target embbedings projected to unified space (i.e. ENU in
|
// keys: EK in Fig2 : is the keys of the target embeddings projected to unified space (i.e. ENU in
|
||||||
// multi-lingual case)
|
// multi-lingual case)
|
||||||
cli.add<std::string>("--ulr-keys-vectors",
|
cli.add<std::string>("--ulr-keys-vectors",
|
||||||
"Path to file with universal sources embeddings of traget keys from projection into universal space",
|
"Path to file with universal sources embeddings of target keys from projection into universal space",
|
||||||
"");
|
"");
|
||||||
cli.add<bool>("--ulr-trainable-transformation",
|
cli.add<bool>("--ulr-trainable-transformation",
|
||||||
"Make Query Transformation Matrix A trainable");
|
"Make Query Transformation Matrix A trainable");
|
||||||
|
@ -10,7 +10,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is neccessary, remove if no problems occur.
|
#define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is necessary, remove if no problems occur.
|
||||||
#define NodeOp(op) [=]() { op; }
|
#define NodeOp(op) [=]() { op; }
|
||||||
|
|
||||||
// helper macro to disable optimization (gcc only)
|
// helper macro to disable optimization (gcc only)
|
||||||
|
@ -136,11 +136,11 @@ static void setErrorHandlers() {
|
|||||||
|
|
||||||
// modify the log pattern for the "general" logger to include the MPI rank
|
// modify the log pattern for the "general" logger to include the MPI rank
|
||||||
// This is called upon initializing MPI. It is needed to associated error messages to ranks.
|
// This is called upon initializing MPI. It is needed to associated error messages to ranks.
|
||||||
void switchtoMultinodeLogging(std::string nodeIdStr) {
|
void switchToMultinodeLogging(std::string nodeIdStr) {
|
||||||
Logger log = spdlog::get("general");
|
Logger log = spdlog::get("general");
|
||||||
if(log)
|
if(log)
|
||||||
log->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] %v", nodeIdStr));
|
log->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] %v", nodeIdStr));
|
||||||
|
|
||||||
Logger valid = spdlog::get("valid");
|
Logger valid = spdlog::get("valid");
|
||||||
if(valid)
|
if(valid)
|
||||||
valid->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] [valid] %v", nodeIdStr));
|
valid->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] [valid] %v", nodeIdStr));
|
||||||
|
@ -12,19 +12,19 @@ namespace marian {
|
|||||||
std::string getCallStack(size_t skipLevels);
|
std::string getCallStack(size_t skipLevels);
|
||||||
|
|
||||||
// Marian gives a basic exception guarantee. If you catch a
|
// Marian gives a basic exception guarantee. If you catch a
|
||||||
// MarianRuntimeError you must assume that the object can be
|
// MarianRuntimeError you must assume that the object can be
|
||||||
// safely destructed, but cannot be used otherwise.
|
// safely destructed, but cannot be used otherwise.
|
||||||
|
|
||||||
// Internal multi-threading in exception-throwing mode is not
|
// Internal multi-threading in exception-throwing mode is not
|
||||||
// allowed; and constructing a thread-pool will cause an exception.
|
// allowed; and constructing a thread-pool will cause an exception.
|
||||||
|
|
||||||
class MarianRuntimeException : public std::runtime_error {
|
class MarianRuntimeException : public std::runtime_error {
|
||||||
private:
|
private:
|
||||||
std::string callStack_;
|
std::string callStack_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MarianRuntimeException(const std::string& message, const std::string& callStack)
|
MarianRuntimeException(const std::string& message, const std::string& callStack)
|
||||||
: std::runtime_error(message),
|
: std::runtime_error(message),
|
||||||
callStack_(callStack) {}
|
callStack_(callStack) {}
|
||||||
|
|
||||||
const char* getCallStack() const throw() {
|
const char* getCallStack() const throw() {
|
||||||
@ -178,4 +178,4 @@ void checkedLog(std::string logger, std::string level, Args... args) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void createLoggers(const marian::Config* options = nullptr);
|
void createLoggers(const marian::Config* options = nullptr);
|
||||||
void switchtoMultinodeLogging(std::string nodeIdStr);
|
void switchToMultinodeLogging(std::string nodeIdStr);
|
||||||
|
@ -98,7 +98,7 @@ public:
|
|||||||
* @brief Splice options from a YAML node
|
* @brief Splice options from a YAML node
|
||||||
*
|
*
|
||||||
* By default, only options with keys that do not already exist in options_ are extracted from
|
* By default, only options with keys that do not already exist in options_ are extracted from
|
||||||
* node. These options are cloned if overwirte is true.
|
* node. These options are cloned if overwrite is true.
|
||||||
*
|
*
|
||||||
* @param node a YAML node to transfer the options from
|
* @param node a YAML node to transfer the options from
|
||||||
* @param overwrite overwrite all options
|
* @param overwrite overwrite all options
|
||||||
|
@ -379,7 +379,7 @@ public:
|
|||||||
* @see marian::data::SubBatch::split(size_t n)
|
* @see marian::data::SubBatch::split(size_t n)
|
||||||
*/
|
*/
|
||||||
std::vector<Ptr<Batch>> split(size_t n, size_t sizeLimit /*=SIZE_MAX*/) override {
|
std::vector<Ptr<Batch>> split(size_t n, size_t sizeLimit /*=SIZE_MAX*/) override {
|
||||||
ABORT_IF(size() == 0, "Encoutered batch size of 0");
|
ABORT_IF(size() == 0, "Encountered batch size of 0");
|
||||||
|
|
||||||
std::vector<std::vector<Ptr<SubBatch>>> subs; // [subBatchIndex][streamIndex]
|
std::vector<std::vector<Ptr<SubBatch>>> subs; // [subBatchIndex][streamIndex]
|
||||||
// split each stream separately
|
// split each stream separately
|
||||||
@ -523,8 +523,8 @@ class CorpusBase : public DatasetBase<SentenceTuple, CorpusIterator, CorpusBatch
|
|||||||
public:
|
public:
|
||||||
typedef SentenceTuple Sample;
|
typedef SentenceTuple Sample;
|
||||||
|
|
||||||
CorpusBase(Ptr<Options> options,
|
CorpusBase(Ptr<Options> options,
|
||||||
bool translate = false,
|
bool translate = false,
|
||||||
size_t seed = Config::seed);
|
size_t seed = Config::seed);
|
||||||
|
|
||||||
CorpusBase(const std::vector<std::string>& paths,
|
CorpusBase(const std::vector<std::string>& paths,
|
||||||
|
@ -44,7 +44,7 @@ public:
|
|||||||
virtual void prepare() {}
|
virtual void prepare() {}
|
||||||
virtual void restore(Ptr<TrainingState>) {}
|
virtual void restore(Ptr<TrainingState>) {}
|
||||||
|
|
||||||
// @TODO: remove after cleaning traininig/training.h
|
// @TODO: remove after cleaning training/training.h
|
||||||
virtual Ptr<Options> options() { return options_; }
|
virtual Ptr<Options> options() { return options_; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ namespace marian {
|
|||||||
class IVocab;
|
class IVocab;
|
||||||
|
|
||||||
// Wrapper around vocabulary types. Can choose underlying
|
// Wrapper around vocabulary types. Can choose underlying
|
||||||
// vocabulary implementation (vImpl_) based on speficied path
|
// vocabulary implementation (vImpl_) based on specified path
|
||||||
// and suffix.
|
// and suffix.
|
||||||
// Vocabulary implementations can currently be:
|
// Vocabulary implementations can currently be:
|
||||||
// * DefaultVocabulary for YAML (*.yml and *.yaml) and TXT (any other non-specific ending)
|
// * DefaultVocabulary for YAML (*.yml and *.yaml) and TXT (any other non-specific ending)
|
||||||
|
@ -76,7 +76,7 @@ public:
|
|||||||
|
|
||||||
Ptr<Allocator> getAllocator() { return tensors_->allocator(); }
|
Ptr<Allocator> getAllocator() { return tensors_->allocator(); }
|
||||||
Ptr<TensorAllocator> getTensorAllocator() { return tensors_; }
|
Ptr<TensorAllocator> getTensorAllocator() { return tensors_; }
|
||||||
|
|
||||||
Expr findOrRemember(Expr node) {
|
Expr findOrRemember(Expr node) {
|
||||||
size_t hash = node->hash();
|
size_t hash = node->hash();
|
||||||
// memoize constant nodes that are not parameters
|
// memoize constant nodes that are not parameters
|
||||||
@ -359,9 +359,9 @@ private:
|
|||||||
|
|
||||||
// Find the named parameter and its typed parent parameter object (params) and return both.
|
// Find the named parameter and its typed parent parameter object (params) and return both.
|
||||||
// If the parameter is not found return the parent parameter object that the parameter should be added to.
|
// If the parameter is not found return the parent parameter object that the parameter should be added to.
|
||||||
// Return [nullptr, nullptr] if no matching parent parameter object exists.
|
// Return [nullptr, nullptr] if no matching parent parameter object exists.
|
||||||
std::tuple<Expr, Ptr<Parameters>> findParams(const std::string& name,
|
std::tuple<Expr, Ptr<Parameters>> findParams(const std::string& name,
|
||||||
Type elementType,
|
Type elementType,
|
||||||
bool typeSpecified) const {
|
bool typeSpecified) const {
|
||||||
Expr p; Ptr<Parameters> params;
|
Expr p; Ptr<Parameters> params;
|
||||||
if(typeSpecified) { // type has been specified, so we are only allowed to look for a parameter with that type
|
if(typeSpecified) { // type has been specified, so we are only allowed to look for a parameter with that type
|
||||||
@ -373,12 +373,12 @@ private:
|
|||||||
} else { // type has not been specified, so we take any type as long as the name matches
|
} else { // type has not been specified, so we take any type as long as the name matches
|
||||||
for(auto kvParams : paramsByElementType_) {
|
for(auto kvParams : paramsByElementType_) {
|
||||||
p = kvParams.second->get(name);
|
p = kvParams.second->get(name);
|
||||||
|
|
||||||
if(p) { // p has been found, return with matching params object
|
if(p) { // p has been found, return with matching params object
|
||||||
params = kvParams.second;
|
params = kvParams.second;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(kvParams.first == elementType) // even if p has not been found, set the params object to be returned
|
if(kvParams.first == elementType) // even if p has not been found, set the params object to be returned
|
||||||
params = kvParams.second;
|
params = kvParams.second;
|
||||||
}
|
}
|
||||||
@ -399,8 +399,8 @@ private:
|
|||||||
|
|
||||||
Expr p; Ptr<Parameters> params; std::tie
|
Expr p; Ptr<Parameters> params; std::tie
|
||||||
(p, params) = findParams(name, elementType, typeSpecified);
|
(p, params) = findParams(name, elementType, typeSpecified);
|
||||||
|
|
||||||
if(!params) {
|
if(!params) {
|
||||||
params = New<Parameters>(elementType);
|
params = New<Parameters>(elementType);
|
||||||
params->init(backend_);
|
params->init(backend_);
|
||||||
paramsByElementType_.insert({elementType, params});
|
paramsByElementType_.insert({elementType, params});
|
||||||
@ -632,13 +632,13 @@ public:
|
|||||||
* Return the Parameters object related to the graph.
|
* Return the Parameters object related to the graph.
|
||||||
* The Parameters object holds the whole set of the parameter nodes.
|
* The Parameters object holds the whole set of the parameter nodes.
|
||||||
*/
|
*/
|
||||||
Ptr<Parameters>& params() {
|
Ptr<Parameters>& params() {
|
||||||
// There are no parameter objects, that's weird.
|
// There are no parameter objects, that's weird.
|
||||||
ABORT_IF(paramsByElementType_.empty(), "No parameter object has been created");
|
ABORT_IF(paramsByElementType_.empty(), "No parameter object has been created");
|
||||||
|
|
||||||
// Safeguard against accessing parameters from the outside with multiple parameter types, not yet supported
|
// Safeguard against accessing parameters from the outside with multiple parameter types, not yet supported
|
||||||
ABORT_IF(paramsByElementType_.size() > 1, "Calling of params() is currently not supported with multiple ({}) parameters", paramsByElementType_.size());
|
ABORT_IF(paramsByElementType_.size() > 1, "Calling of params() is currently not supported with multiple ({}) parameters", paramsByElementType_.size());
|
||||||
|
|
||||||
// Safeguard against accessing parameters from the outside with other than default parameter type, not yet supported
|
// Safeguard against accessing parameters from the outside with other than default parameter type, not yet supported
|
||||||
auto it = paramsByElementType_.find(defaultElementType_);
|
auto it = paramsByElementType_.find(defaultElementType_);
|
||||||
ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_);
|
ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_);
|
||||||
@ -650,7 +650,7 @@ public:
|
|||||||
* Return the Parameters object related to the graph by elementType.
|
* Return the Parameters object related to the graph by elementType.
|
||||||
* The Parameters object holds the whole set of the parameter nodes of the given type.
|
* The Parameters object holds the whole set of the parameter nodes of the given type.
|
||||||
*/
|
*/
|
||||||
Ptr<Parameters>& params(Type elementType) {
|
Ptr<Parameters>& params(Type elementType) {
|
||||||
auto it = paramsByElementType_.find(elementType);
|
auto it = paramsByElementType_.find(elementType);
|
||||||
ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_);
|
ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_);
|
||||||
return it->second;
|
return it->second;
|
||||||
@ -661,8 +661,8 @@ public:
|
|||||||
* The default value is used if some node type is not specified.
|
* The default value is used if some node type is not specified.
|
||||||
*/
|
*/
|
||||||
void setDefaultElementType(Type defaultElementType) {
|
void setDefaultElementType(Type defaultElementType) {
|
||||||
ABORT_IF(!paramsByElementType_.empty() && defaultElementType != defaultElementType_,
|
ABORT_IF(!paramsByElementType_.empty() && defaultElementType != defaultElementType_,
|
||||||
"Parameter objects already exist, cannot change default type from {} to {}",
|
"Parameter objects already exist, cannot change default type from {} to {}",
|
||||||
defaultElementType_, defaultElementType);
|
defaultElementType_, defaultElementType);
|
||||||
defaultElementType_ = defaultElementType;
|
defaultElementType_ = defaultElementType;
|
||||||
}
|
}
|
||||||
@ -746,7 +746,7 @@ public:
|
|||||||
// skip over special parameters starting with "special:"
|
// skip over special parameters starting with "special:"
|
||||||
if(pName.substr(0, 8) == "special:")
|
if(pName.substr(0, 8) == "special:")
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// if during loading the loaded type is of the same type class as the default element type, allow conversion;
|
// if during loading the loaded type is of the same type class as the default element type, allow conversion;
|
||||||
// otherwise keep the loaded type. This is used when e.g. loading a float32 model as a float16 model as both
|
// otherwise keep the loaded type. This is used when e.g. loading a float32 model as a float16 model as both
|
||||||
// have type class TypeClass::float_type.
|
// have type class TypeClass::float_type.
|
||||||
@ -781,9 +781,9 @@ public:
|
|||||||
|
|
||||||
LOG(info, "Memory mapping model at {}", ptr);
|
LOG(info, "Memory mapping model at {}", ptr);
|
||||||
auto items = io::mmapItems(ptr);
|
auto items = io::mmapItems(ptr);
|
||||||
|
|
||||||
// Deal with default parameter set object that might not be a mapped object.
|
// Deal with default parameter set object that might not be a mapped object.
|
||||||
// This gets assigned during ExpressionGraph::setDevice(...) and by default
|
// This gets assigned during ExpressionGraph::setDevice(...) and by default
|
||||||
// would contain allocated tensors. Here we replace it with a mmapped version.
|
// would contain allocated tensors. Here we replace it with a mmapped version.
|
||||||
auto it = paramsByElementType_.find(defaultElementType_);
|
auto it = paramsByElementType_.find(defaultElementType_);
|
||||||
if(it != paramsByElementType_.end()) {
|
if(it != paramsByElementType_.end()) {
|
||||||
|
@ -27,12 +27,12 @@ Expr checkpoint(Expr a) {
|
|||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
|
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
|
||||||
LambdaNodeFunctor fwd, size_t hash) {
|
LambdaNodeFunctor fwd, size_t hash) {
|
||||||
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
|
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
|
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
|
||||||
LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash) {
|
LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash) {
|
||||||
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd, hash);
|
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd, hash);
|
||||||
}
|
}
|
||||||
@ -436,7 +436,7 @@ Expr std(Expr a, int ax) {
|
|||||||
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::rms);
|
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::rms);
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr var(Expr a, int ax) {
|
Expr var(Expr a, int ax) {
|
||||||
if(a->shape()[ax] == 1) // nothing to reduce, var(a) = 0
|
if(a->shape()[ax] == 1) // nothing to reduce, var(a) = 0
|
||||||
return a - a;
|
return a - a;
|
||||||
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::meanSqr);
|
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::meanSqr);
|
||||||
@ -575,8 +575,8 @@ Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float sc
|
|||||||
return Expression<AffineNodeOp>(nodes, transA, transB, scale);
|
return Expression<AffineNodeOp>(nodes, transA, transB, scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
// This operation used to implement auto-tuning. We have removed it for now due to complexity, but plan to revisit it in the future.
|
// This operation used to implement auto-tuning. We have removed it for now due to complexity, but plan to revisit it in the future.
|
||||||
// The last branch with auto-tuner is:
|
// The last branch with auto-tuner is:
|
||||||
// youki/packed-model-pr-backup1031
|
// youki/packed-model-pr-backup1031
|
||||||
// https://machinetranslation.visualstudio.com/Marian/_git/marian-dev?version=GByouki%2Fpacked-model-pr-backup1031
|
// https://machinetranslation.visualstudio.com/Marian/_git/marian-dev?version=GByouki%2Fpacked-model-pr-backup1031
|
||||||
// SHA: 3456a7ed1d1608cfad74cd2c414e7e8fe141aa52
|
// SHA: 3456a7ed1d1608cfad74cd2c414e7e8fe141aa52
|
||||||
@ -660,8 +660,8 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Default GEMM
|
// Default GEMM
|
||||||
ABORT_IF(!isFloat(aElementType) || !isFloat(bElementType),
|
ABORT_IF(!isFloat(aElementType) || !isFloat(bElementType),
|
||||||
"GPU-based GEMM only supports float types, you have A: {} and B: {}",
|
"GPU-based GEMM only supports float types, you have A: {} and B: {}",
|
||||||
aElementType, bElementType);
|
aElementType, bElementType);
|
||||||
return affineDefault(a, b, bias, transA, transB, scale);
|
return affineDefault(a, b, bias, transA, transB, scale);
|
||||||
}
|
}
|
||||||
@ -669,7 +669,7 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
|
|||||||
|
|
||||||
Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
|
Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
|
||||||
auto graph = a->graph();
|
auto graph = a->graph();
|
||||||
|
|
||||||
if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu)
|
if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu)
|
||||||
return Expression<AffineWithReluNodeOp>(a, b, bias, transA, transB, scale);
|
return Expression<AffineWithReluNodeOp>(a, b, bias, transA, transB, scale);
|
||||||
else
|
else
|
||||||
@ -775,7 +775,7 @@ Expr unlikelihood(Expr logits, Expr indices) {
|
|||||||
int dimBatch = logits->shape()[-2];
|
int dimBatch = logits->shape()[-2];
|
||||||
int dimTime = logits->shape()[-3];
|
int dimTime = logits->shape()[-3];
|
||||||
|
|
||||||
// @TODO: fix this outside of this function in decoder.h etc.
|
// @TODO: fix this outside of this function in decoder.h etc.
|
||||||
auto indicesWithLayout = reshape(indices, {1, dimTime, dimBatch, 1});
|
auto indicesWithLayout = reshape(indices, {1, dimTime, dimBatch, 1});
|
||||||
|
|
||||||
// This is currently implemented with multiple ops, might be worth doing a special operation like for cross_entropy
|
// This is currently implemented with multiple ops, might be worth doing a special operation like for cross_entropy
|
||||||
|
@ -70,9 +70,9 @@ public:
|
|||||||
outputs.push_back(output2);
|
outputs.push_back(output2);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto concated = concatenate(outputs, -1);
|
auto concatenated = concatenate(outputs, -1);
|
||||||
|
|
||||||
return concated;
|
return concatenated;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -67,7 +67,7 @@ public:
|
|||||||
return count_->val()->scalar<T>();
|
return count_->val()->scalar<T>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// @TODO: add a funtion for returning maybe ratio?
|
// @TODO: add a function for returning maybe ratio?
|
||||||
|
|
||||||
size_t size() const {
|
size_t size() const {
|
||||||
ABORT_IF(!count_, "Labels have not been defined");
|
ABORT_IF(!count_, "Labels have not been defined");
|
||||||
@ -189,7 +189,7 @@ public:
|
|||||||
*
|
*
|
||||||
* L = sum_i^N L_i + N/M sum_j^M L_j
|
* L = sum_i^N L_i + N/M sum_j^M L_j
|
||||||
*
|
*
|
||||||
* We set labels to N. When reporting L/N this is equvalient to sum of means.
|
* We set labels to N. When reporting L/N this is equivalent to sum of means.
|
||||||
* Compare to sum of means below where N is factored into the loss, but labels
|
* Compare to sum of means below where N is factored into the loss, but labels
|
||||||
* are set to 1.
|
* are set to 1.
|
||||||
*/
|
*/
|
||||||
|
@ -76,7 +76,7 @@ private:
|
|||||||
float scale = sqrtf(2.0f / (dimVoc + dimEmb));
|
float scale = sqrtf(2.0f / (dimVoc + dimEmb));
|
||||||
|
|
||||||
// @TODO: switch to new random generator back-end.
|
// @TODO: switch to new random generator back-end.
|
||||||
// This is rarly used however.
|
// This is rarely used however.
|
||||||
std::random_device rd;
|
std::random_device rd;
|
||||||
std::mt19937 engine(rd());
|
std::mt19937 engine(rd());
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
namespace marian {
|
namespace marian {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This file contains nearly all BERT-related code and adds BERT-funtionality
|
* This file contains nearly all BERT-related code and adds BERT-functionality
|
||||||
* on top of existing classes like TansformerEncoder and Classifier.
|
* on top of existing classes like TansformerEncoder and Classifier.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@ -82,7 +82,7 @@ public:
|
|||||||
// Initialize to sample random vocab id
|
// Initialize to sample random vocab id
|
||||||
randomWord_.reset(new std::uniform_int_distribution<WordIndex>(0, (WordIndex)vocab.size()));
|
randomWord_.reset(new std::uniform_int_distribution<WordIndex>(0, (WordIndex)vocab.size()));
|
||||||
|
|
||||||
// Intialize to sample random percentage
|
// Initialize to sample random percentage
|
||||||
randomPercent_.reset(new std::uniform_real_distribution<float>(0.f, 1.f));
|
randomPercent_.reset(new std::uniform_real_distribution<float>(0.f, 1.f));
|
||||||
|
|
||||||
auto& words = subBatch->data();
|
auto& words = subBatch->data();
|
||||||
|
@ -14,7 +14,7 @@ namespace marian {
|
|||||||
* Can be used to train sequence classifiers like language detection, BERT-next-sentence-prediction etc.
|
* Can be used to train sequence classifiers like language detection, BERT-next-sentence-prediction etc.
|
||||||
* Already has support for multi-objective training.
|
* Already has support for multi-objective training.
|
||||||
*
|
*
|
||||||
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for deocder/classifier
|
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for decoder/classifier
|
||||||
* multi-objective training.
|
* multi-objective training.
|
||||||
*/
|
*/
|
||||||
class EncoderClassifierBase : public models::IModel {
|
class EncoderClassifierBase : public models::IModel {
|
||||||
|
@ -220,7 +220,7 @@ Ptr<DecoderState> EncoderDecoder::stepAll(Ptr<ExpressionGraph> graph,
|
|||||||
if(clearGraph)
|
if(clearGraph)
|
||||||
clear(graph);
|
clear(graph);
|
||||||
|
|
||||||
// Required first step, also intializes shortlist
|
// Required first step, also initializes shortlist
|
||||||
auto state = startState(graph, batch);
|
auto state = startState(graph, batch);
|
||||||
|
|
||||||
// Fill state with embeddings from batch (ground truth)
|
// Fill state with embeddings from batch (ground truth)
|
||||||
|
@ -70,7 +70,7 @@ public:
|
|||||||
// Hack for translating with length longer than trained embeddings
|
// Hack for translating with length longer than trained embeddings
|
||||||
// We check if the embedding matrix "Wpos" already exist so we can
|
// We check if the embedding matrix "Wpos" already exist so we can
|
||||||
// check the number of positions in that loaded parameter.
|
// check the number of positions in that loaded parameter.
|
||||||
// We then have to restict the maximum length to the maximum positon
|
// We then have to restrict the maximum length to the maximum positon
|
||||||
// and positions beyond this will be the maximum position.
|
// and positions beyond this will be the maximum position.
|
||||||
Expr seenEmb = graph_->get("Wpos");
|
Expr seenEmb = graph_->get("Wpos");
|
||||||
int numPos = seenEmb ? seenEmb->shape()[-2] : maxLength;
|
int numPos = seenEmb ? seenEmb->shape()[-2] : maxLength;
|
||||||
|
@ -101,7 +101,7 @@ public:
|
|||||||
std::string maxRankStr = std::to_string(MPIWrapper::numMPIProcesses() -1);
|
std::string maxRankStr = std::to_string(MPIWrapper::numMPIProcesses() -1);
|
||||||
while (rankStr.size() < maxRankStr.size()) // pad so that logs across MPI processes line up nicely
|
while (rankStr.size() < maxRankStr.size()) // pad so that logs across MPI processes line up nicely
|
||||||
rankStr.insert(rankStr.begin(), ' ');
|
rankStr.insert(rankStr.begin(), ' ');
|
||||||
switchtoMultinodeLogging(rankStr);
|
switchToMultinodeLogging(rankStr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// log hostnames in order, and test
|
// log hostnames in order, and test
|
||||||
@ -261,7 +261,7 @@ void finalizeMPI(Ptr<IMPIWrapper>&& mpi) {
|
|||||||
ABORT_IF(mpi == nullptr || mpi != s_mpi, "attempted to finalize an inconsistent MPI instance. This should not be possible.");
|
ABORT_IF(mpi == nullptr || mpi != s_mpi, "attempted to finalize an inconsistent MPI instance. This should not be possible.");
|
||||||
mpi = nullptr; // destruct caller's handle
|
mpi = nullptr; // destruct caller's handle
|
||||||
ABORT_IF(s_mpiUseCount == 0, "finalize called too many times. This should not be possible.");
|
ABORT_IF(s_mpiUseCount == 0, "finalize called too many times. This should not be possible.");
|
||||||
if (s_mpiUseCount == 1) { // last call finalizes MPI, i.e. tells MPI that we sucessfully completed computation
|
if (s_mpiUseCount == 1) { // last call finalizes MPI, i.e. tells MPI that we successfully completed computation
|
||||||
ABORT_IF(s_mpi.use_count() != 1, "dangling reference to MPI??"); // caller kept another shared_ptr to this instance
|
ABORT_IF(s_mpi.use_count() != 1, "dangling reference to MPI??"); // caller kept another shared_ptr to this instance
|
||||||
s_mpi->finalize(); // signal successful completion to MPI
|
s_mpi->finalize(); // signal successful completion to MPI
|
||||||
s_mpi = nullptr; // release the singleton instance upon last finalization
|
s_mpi = nullptr; // release the singleton instance upon last finalization
|
||||||
|
@ -13,7 +13,7 @@ namespace marian {
|
|||||||
|
|
||||||
// With -Ofast enabled gcc will fail to identify NaN or Inf. Safeguard here.
|
// With -Ofast enabled gcc will fail to identify NaN or Inf. Safeguard here.
|
||||||
static inline bool isFinite(float x) {
|
static inline bool isFinite(float x) {
|
||||||
#ifdef __GNUC__
|
#ifdef __GNUC__
|
||||||
ABORT_IF(std::isfinite(0.f / 0.f), "NaN detection unreliable. Disable -Ofast compiler option.");
|
ABORT_IF(std::isfinite(0.f / 0.f), "NaN detection unreliable. Disable -Ofast compiler option.");
|
||||||
#endif
|
#endif
|
||||||
return std::isfinite(x);
|
return std::isfinite(x);
|
||||||
@ -27,7 +27,7 @@ static inline bool isFinite(float x) {
|
|||||||
// if one value is nonfinite propagate Nan into the reduction.
|
// if one value is nonfinite propagate Nan into the reduction.
|
||||||
static inline void accNanOrNorm(float& lhs, float rhs) {
|
static inline void accNanOrNorm(float& lhs, float rhs) {
|
||||||
if(isFinite(lhs) && isFinite(rhs)) {
|
if(isFinite(lhs) && isFinite(rhs)) {
|
||||||
lhs = sqrtf(lhs * lhs + rhs * rhs);
|
lhs = sqrtf(lhs * lhs + rhs * rhs);
|
||||||
} else
|
} else
|
||||||
lhs = std::numeric_limits<float>::quiet_NaN();
|
lhs = std::numeric_limits<float>::quiet_NaN();
|
||||||
}
|
}
|
||||||
@ -42,20 +42,20 @@ static inline void accNanOrNorm(float& lhs, float rhs) {
|
|||||||
class GraphGroup {
|
class GraphGroup {
|
||||||
protected:
|
protected:
|
||||||
Ptr<Options> options_;
|
Ptr<Options> options_;
|
||||||
|
|
||||||
Ptr<ICommunicator> comm_; // [not null] communicator, e.g. NCCLCommunicator
|
Ptr<ICommunicator> comm_; // [not null] communicator, e.g. NCCLCommunicator
|
||||||
Ptr<IMPIWrapper> mpi_; // [not null] all MPI-like communication goes through this (this is a dummy implementation if no MPI run)
|
Ptr<IMPIWrapper> mpi_; // [not null] all MPI-like communication goes through this (this is a dummy implementation if no MPI run)
|
||||||
|
|
||||||
std::vector<DeviceId> devices_; // [deviceIndex]
|
std::vector<DeviceId> devices_; // [deviceIndex]
|
||||||
ShardingMode shardingMode_{ShardingMode::global}; // If local and multi-node training, shard only on local devices and do full sync (faster). If global shard across entire set of GPUs (more RAM).
|
ShardingMode shardingMode_{ShardingMode::global}; // If local and multi-node training, shard only on local devices and do full sync (faster). If global shard across entire set of GPUs (more RAM).
|
||||||
|
|
||||||
// common for all graph groups, individual graph groups decide how to fill them
|
// common for all graph groups, individual graph groups decide how to fill them
|
||||||
std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex]
|
std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex]
|
||||||
std::vector<Ptr<models::ICriterionFunction>> models_; // [deviceIndex]
|
std::vector<Ptr<models::ICriterionFunction>> models_; // [deviceIndex]
|
||||||
std::vector<Ptr<OptimizerBase>> optimizerShards_; // [deviceIndex]
|
std::vector<Ptr<OptimizerBase>> optimizerShards_; // [deviceIndex]
|
||||||
|
|
||||||
Ptr<Scheduler> scheduler_; // scheduler that keeps track of how much has been processed
|
Ptr<Scheduler> scheduler_; // scheduler that keeps track of how much has been processed
|
||||||
|
|
||||||
bool finalized_{false}; // 'true' if training has completed (further updates are no longer allowed)
|
bool finalized_{false}; // 'true' if training has completed (further updates are no longer allowed)
|
||||||
double typicalTrgBatchWords_{0}; // for dynamic batch sizing: typical batch size in words
|
double typicalTrgBatchWords_{0}; // for dynamic batch sizing: typical batch size in words
|
||||||
bool mbRoundUp_{true}; // round up batches for more efficient training but can make batch size less stable, disable with --mini-batch-round-up=false
|
bool mbRoundUp_{true}; // round up batches for more efficient training but can make batch size less stable, disable with --mini-batch-round-up=false
|
||||||
@ -100,16 +100,16 @@ public:
|
|||||||
|
|
||||||
virtual void load();
|
virtual void load();
|
||||||
virtual void save(bool isFinal = false);
|
virtual void save(bool isFinal = false);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void load(const OptimizerBase::ScatterStateFunc& scatterFn);
|
void load(const OptimizerBase::ScatterStateFunc& scatterFn);
|
||||||
void save(bool isFinal,
|
void save(bool isFinal,
|
||||||
const OptimizerBase::GatherStateFunc& gatherOptimizerStateFn);
|
const OptimizerBase::GatherStateFunc& gatherOptimizerStateFn);
|
||||||
|
|
||||||
bool restoreFromCheckpoint(const std::string& modelFileName,
|
bool restoreFromCheckpoint(const std::string& modelFileName,
|
||||||
const OptimizerBase::ScatterStateFunc& scatterFn);
|
const OptimizerBase::ScatterStateFunc& scatterFn);
|
||||||
|
|
||||||
void saveCheckpoint(const std::string& modelFileName,
|
void saveCheckpoint(const std::string& modelFileName,
|
||||||
const OptimizerBase::GatherStateFunc& gatherFn);
|
const OptimizerBase::GatherStateFunc& gatherFn);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -128,11 +128,11 @@ public:
|
|||||||
float executeAndCollectNorm(const std::function<float(size_t, size_t, size_t)>& task);
|
float executeAndCollectNorm(const std::function<float(size_t, size_t, size_t)>& task);
|
||||||
|
|
||||||
float computeNormalizationFactor(float gNorm, size_t updateTrgWords);
|
float computeNormalizationFactor(float gNorm, size_t updateTrgWords);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Determine maximal batch size that can fit into the given workspace
|
* Determine maximal batch size that can fit into the given workspace
|
||||||
* so that reallocation does not happen. Rather adjust the batch size
|
* so that reallocation does not happen. Rather adjust the batch size
|
||||||
* based on the stastistics collected here. Activated with
|
* based on the statistics collected here. Activated with
|
||||||
* `--mini-batch-fit`.
|
* `--mini-batch-fit`.
|
||||||
* In a multi-GPU scenario, the first GPU is used to determine the size.
|
* In a multi-GPU scenario, the first GPU is used to determine the size.
|
||||||
* The actual allowed size is then determined by multiplying it with the
|
* The actual allowed size is then determined by multiplying it with the
|
||||||
@ -151,4 +151,4 @@ public:
|
|||||||
void updateAverageTrgBatchWords(size_t trgBatchWords);
|
void updateAverageTrgBatchWords(size_t trgBatchWords);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace marian
|
} // namespace marian
|
||||||
|
Loading…
Reference in New Issue
Block a user