Cherry picked cleaning/refactoring patches (#905)

Cherry-picked updates from pull request #457

Co-authored-by: Mateusz Chudyk <mateuszchudyk@gmail.com>
This commit is contained in:
Roman Grundkiewicz 2022-01-28 14:16:41 +00:00 committed by GitHub
parent 71b5454b9e
commit 07c39c7d76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 113 additions and 108 deletions

View File

@ -113,10 +113,10 @@ std::string CLIWrapper::switchGroup(std::string name) {
return name;
}
void CLIWrapper::parse(int argc, char **argv) {
void CLIWrapper::parse(int argc, char** argv) {
try {
app_->parse(argc, argv);
} catch(const CLI::ParseError &e) {
} catch(const CLI::ParseError& e) {
exit(app_->exit(e));
}
@ -182,6 +182,13 @@ void CLIWrapper::parseAliases() {
}
}
// Extract the canonical option key from a comma-separated list of long and
// short option names, e.g. returns "help" for "--help,-h".
std::string CLIWrapper::keyName(const std::string& args) const {
// re-use existing functions from CLI11 to keep option names consistent
return std::get<1>(
CLI::detail::get_names(CLI::detail::split_names(args))) // get long names only
.front(); // get first long name
}
void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg) {
auto cmdOptions = getParsedOptionNames();
// Keep track of unrecognized options from the provided config
@ -276,7 +283,7 @@ std::vector<std::string> CLIWrapper::getOrderedOptionNames() const {
for(auto const &it : options_)
keys.push_back(it.first);
// sort option names by creation index
sort(keys.begin(), keys.end(), [this](const std::string &a, const std::string &b) {
sort(keys.begin(), keys.end(), [this](const std::string& a, const std::string& b) {
return options_.at(a).idx < options_.at(b).idx;
});
return keys;

View File

@ -44,7 +44,7 @@ struct CLIAliasTuple {
class CLIFormatter : public CLI::Formatter {
public:
CLIFormatter(size_t columnWidth, size_t screenWidth);
virtual std::string make_option_desc(const CLI::Option *) const override;
virtual std::string make_option_desc(const CLI::Option*) const override;
private:
size_t screenWidth_{0};
@ -85,12 +85,7 @@ private:
// Extract option name from a comma-separated list of long and short options, e.g. 'help' from
// '--help,-h'
std::string keyName(const std::string &args) const {
// re-use existing functions from CLI11 to keep option names consistent
return std::get<1>(
CLI::detail::get_names(CLI::detail::split_names(args))) // get long names only
.front(); // get first long name
}
std::string keyName(const std::string &args) const;
// Get names of options passed via command-line
std::unordered_set<std::string> getParsedOptionNames() const;
@ -134,7 +129,7 @@ public:
* @return Option object
*/
template <typename T>
CLI::Option *add(const std::string &args, const std::string &help, T val) {
CLI::Option* add(const std::string& args, const std::string& help, T val) {
return addOption<T>(keyName(args),
args,
help,
@ -159,7 +154,7 @@ public:
* @TODO: require to always state the default value creating the parser as this will be clearer
*/
template <typename T>
CLI::Option *add(const std::string &args, const std::string &help) {
CLI::Option* add(const std::string& args, const std::string& help) {
return addOption<T>(keyName(args),
args,
help,
@ -206,7 +201,7 @@ public:
std::string switchGroup(std::string name = "");
// Parse command-line arguments. Handles --help and --version options
void parse(int argc, char **argv);
void parse(int argc, char** argv);
/**
* @brief Expand aliases based on arguments parsed with parse(int, char**)
@ -240,11 +235,12 @@ public:
std::string dumpConfig(bool skipUnmodified = false) const;
private:
template <typename T,
// options with numeric and string-like values
CLI::enable_if_t<!CLI::is_bool<T>::value && !CLI::is_vector<T>::value,
CLI::detail::enabler> = CLI::detail::dummy>
CLI::Option *addOption(const std::string &key,
template <typename T>
using EnableIfNumbericOrString = CLI::enable_if_t<!CLI::is_bool<T>::value
&& !CLI::is_vector<T>::value, CLI::detail::enabler>;
template <typename T, EnableIfNumbericOrString<T> = CLI::detail::dummy>
CLI::Option* addOption(const std::string &key,
const std::string &args,
const std::string &help,
T val,
@ -261,7 +257,7 @@ private:
CLI::callback_t fun = [this, key](CLI::results_t res) {
options_[key].priority = cli::OptionPriority::CommandLine;
// get variable associated with the option
auto &var = options_[key].var->as<T>();
auto& var = options_[key].var->as<T>();
// store parser result in var
auto ret = CLI::detail::lexical_cast(res[0], var);
// update YAML entry
@ -288,10 +284,11 @@ private:
return options_[key].opt;
}
template <typename T,
// options with vector values
CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
CLI::Option *addOption(const std::string &key,
template <typename T>
using EnableIfVector = CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler>;
template <typename T, EnableIfVector<T> = CLI::detail::dummy>
CLI::Option* addOption(const std::string &key,
const std::string &args,
const std::string &help,
T val,
@ -308,7 +305,7 @@ private:
CLI::callback_t fun = [this, key](CLI::results_t res) {
options_[key].priority = cli::OptionPriority::CommandLine;
// get vector variable associated with the option
auto &vec = options_[key].var->as<T>();
auto& vec = options_[key].var->as<T>();
vec.clear();
bool ret = true;
// handle '[]' as an empty vector
@ -316,7 +313,7 @@ private:
ret = true;
} else {
// populate the vector with parser results
for(const auto &a : res) {
for(const auto& a : res) {
vec.emplace_back();
ret &= CLI::detail::lexical_cast(a, vec.back());
}
@ -345,10 +342,11 @@ private:
return options_[key].opt;
}
template <typename T,
// options with boolean values, called flags in CLI11
CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
CLI::Option *addOption(const std::string &key,
template <typename T>
using EnableIfBoolean = CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler>;
template <typename T, EnableIfBoolean<T> = CLI::detail::dummy>
CLI::Option* addOption(const std::string &key,
const std::string &args,
const std::string &help,
T val,

View File

@ -107,7 +107,7 @@ private:
* @param mode change the set of available command-line options, e.g. training, translation, etc.
* @param validate validate parsed options and abort on failure
*
* @return parsed otions
* @return parsed options
*/
Ptr<Options> parseOptions(int argc,
char** argv,

View File

@ -119,10 +119,10 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
cli.add<std::vector<std::string>>("--config,-c",
"Configuration file(s). If multiple, later overrides earlier");
cli.add<size_t>("--workspace,-w",
"Preallocate arg MB of work space",
"Preallocate arg MB of work space",
defaultWorkspace);
cli.add<std::string>("--log",
"Log training process information to file given by arg");
"Log training process information to file given by arg");
cli.add<std::string>("--log-level",
"Set verbosity level of logging: trace, debug, info, warn, err(or), critical, off",
"info");
@ -392,17 +392,17 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"Finish after this many chosen training units, 0 is infinity (e.g. 100e = 100 epochs, 10Gt = 10 billion target labels, 100Ku = 100,000 updates",
"0e");
cli.add<std::string/*SchedulerPeriod*/>("--disp-freq",
"Display information every arg updates (append 't' for every arg target labels)",
"Display information every arg updates (append 't' for every arg target labels)",
"1000u");
cli.add<size_t>("--disp-first",
"Display information for the first arg updates");
"Display information for the first arg updates");
cli.add<bool>("--disp-label-counts",
"Display label counts when logging loss progress",
true);
// cli.add<int>("--disp-label-index",
// "Display label counts based on i-th input stream (-1 is last)", -1);
cli.add<std::string/*SchedulerPeriod*/>("--save-freq",
"Save model file every arg updates (append 't' for every arg target labels)",
"Save model file every arg updates (append 't' for every arg target labels)",
"10000u");
cli.add<std::vector<std::string>>("--logical-epoch",
"Redefine logical epoch counter as multiple of data epochs (e.g. 1e), updates (e.g. 100Ku) or labels (e.g. 1Gt). "
@ -473,12 +473,12 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
cli.add<bool>("--lr-decay-repeat-warmup",
"Repeat learning rate warmup when learning rate is decayed");
cli.add<std::vector<std::string/*SchedulerPeriod*/>>("--lr-decay-inv-sqrt",
"Decrease learning rate at arg / sqrt(no. batches) starting at arg (append 't' or 'e' for sqrt(target labels or epochs)). "
"Decrease learning rate at arg / sqrt(no. batches) starting at arg (append 't' or 'e' for sqrt(target labels or epochs)). "
"Add second argument to define the starting point (default: same as first value)",
{"0"});
cli.add<std::string/*SchedulerPeriod*/>("--lr-warmup",
"Increase learning rate linearly for arg first batches (append 't' for arg first target labels)",
"Increase learning rate linearly for arg first batches (append 't' for arg first target labels)",
"0");
cli.add<float>("--lr-warmup-start-rate",
"Start value for learning rate warmup");
@ -492,7 +492,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
cli.add<double>("--factor-weight",
"Weight for loss function for factors (factored vocab only) (1 to disable)", 1.0f);
cli.add<float>("--clip-norm",
"Clip gradient norm to arg (0 to disable)",
"Clip gradient norm to arg (0 to disable)",
1.f); // @TODO: this is currently wrong with ce-sum and should rather be disabled or fixed by multiplying with labels
cli.add<float>("--exponential-smoothing",
"Maintain smoothed version of parameters for validation and saving with smoothing factor. 0 to disable. "
@ -575,7 +575,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
cli.add<std::vector<std::string>>("--valid-sets",
"Paths to validation corpora: source target");
cli.add<std::string/*SchedulerPeriod*/>("--valid-freq",
"Validate model every arg updates (append 't' for every arg target labels)",
"Validate model every arg updates (append 't' for every arg target labels)",
"10000u");
cli.add<std::vector<std::string>>("--valid-metrics",
"Metric to use during validation: cross-entropy, ce-mean-words, perplexity, valid-script, "
@ -585,7 +585,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
cli.add<bool>("--valid-reset-stalled",
"Reset all stalled validation metrics when the training is restarted");
cli.add<size_t>("--early-stopping",
"Stop if the first validation metric does not improve for arg consecutive validation steps",
"Stop if the first validation metric does not improve for arg consecutive validation steps",
10);
cli.add<std::string>("--early-stopping-on",
"Decide if early stopping should take into account first, all, or any validation metrics"
@ -637,7 +637,7 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
cli.add<bool>("--keep-best",
"Keep best model for each validation metric");
cli.add<std::string>("--valid-log",
"Log validation scores to file given by arg");
"Log validation scores to file given by arg");
cli.switchGroup(previous_group);
// clang-format on
}
@ -942,10 +942,10 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
cli.add<std::string>("--ulr-query-vectors",
"Path to file with universal sources embeddings from projection into universal space",
"");
// keys: EK in Fig2 : is the keys of the target embbedings projected to unified space (i.e. ENU in
// keys: EK in Fig2 : is the keys of the target embeddings projected to unified space (i.e. ENU in
// multi-lingual case)
cli.add<std::string>("--ulr-keys-vectors",
"Path to file with universal sources embeddings of traget keys from projection into universal space",
"Path to file with universal sources embeddings of target keys from projection into universal space",
"");
cli.add<bool>("--ulr-trainable-transformation",
"Make Query Transformation Matrix A trainable");

View File

@ -10,7 +10,7 @@
#include <string>
#include <vector>
#define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is neccessary, remove if no problems occur.
#define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is necessary, remove if no problems occur.
#define NodeOp(op) [=]() { op; }
// helper macro to disable optimization (gcc only)

View File

@ -136,11 +136,11 @@ static void setErrorHandlers() {
// modify the log pattern for the "general" logger to include the MPI rank
// This is called upon initializing MPI. It is needed to associated error messages to ranks.
void switchtoMultinodeLogging(std::string nodeIdStr) {
void switchToMultinodeLogging(std::string nodeIdStr) {
Logger log = spdlog::get("general");
if(log)
log->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] %v", nodeIdStr));
Logger valid = spdlog::get("valid");
if(valid)
valid->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] [valid] %v", nodeIdStr));

View File

@ -12,19 +12,19 @@ namespace marian {
std::string getCallStack(size_t skipLevels);
// Marian gives a basic exception guarantee. If you catch a
// MarianRuntimeError you must assume that the object can be
// MarianRuntimeError you must assume that the object can be
// safely destructed, but cannot be used otherwise.
// Internal multi-threading in exception-throwing mode is not
// Internal multi-threading in exception-throwing mode is not
// allowed; and constructing a thread-pool will cause an exception.
class MarianRuntimeException : public std::runtime_error {
private:
std::string callStack_;
public:
MarianRuntimeException(const std::string& message, const std::string& callStack)
: std::runtime_error(message),
MarianRuntimeException(const std::string& message, const std::string& callStack)
: std::runtime_error(message),
callStack_(callStack) {}
const char* getCallStack() const throw() {
@ -178,4 +178,4 @@ void checkedLog(std::string logger, std::string level, Args... args) {
}
void createLoggers(const marian::Config* options = nullptr);
void switchtoMultinodeLogging(std::string nodeIdStr);
void switchToMultinodeLogging(std::string nodeIdStr);

View File

@ -98,7 +98,7 @@ public:
* @brief Splice options from a YAML node
*
* By default, only options with keys that do not already exist in options_ are extracted from
* node. These options are cloned if overwirte is true.
* node. These options are cloned if overwrite is true.
*
* @param node a YAML node to transfer the options from
* @param overwrite overwrite all options

View File

@ -379,7 +379,7 @@ public:
* @see marian::data::SubBatch::split(size_t n)
*/
std::vector<Ptr<Batch>> split(size_t n, size_t sizeLimit /*=SIZE_MAX*/) override {
ABORT_IF(size() == 0, "Encoutered batch size of 0");
ABORT_IF(size() == 0, "Encountered batch size of 0");
std::vector<std::vector<Ptr<SubBatch>>> subs; // [subBatchIndex][streamIndex]
// split each stream separately
@ -523,8 +523,8 @@ class CorpusBase : public DatasetBase<SentenceTuple, CorpusIterator, CorpusBatch
public:
typedef SentenceTuple Sample;
CorpusBase(Ptr<Options> options,
bool translate = false,
CorpusBase(Ptr<Options> options,
bool translate = false,
size_t seed = Config::seed);
CorpusBase(const std::vector<std::string>& paths,

View File

@ -44,7 +44,7 @@ public:
virtual void prepare() {}
virtual void restore(Ptr<TrainingState>) {}
// @TODO: remove after cleaning traininig/training.h
// @TODO: remove after cleaning training/training.h
virtual Ptr<Options> options() { return options_; }
};

View File

@ -10,7 +10,7 @@ namespace marian {
class IVocab;
// Wrapper around vocabulary types. Can choose underlying
// vocabulary implementation (vImpl_) based on speficied path
// vocabulary implementation (vImpl_) based on specified path
// and suffix.
// Vocabulary implementations can currently be:
// * DefaultVocabulary for YAML (*.yml and *.yaml) and TXT (any other non-specific ending)

View File

@ -76,7 +76,7 @@ public:
Ptr<Allocator> getAllocator() { return tensors_->allocator(); }
Ptr<TensorAllocator> getTensorAllocator() { return tensors_; }
Expr findOrRemember(Expr node) {
size_t hash = node->hash();
// memoize constant nodes that are not parameters
@ -359,9 +359,9 @@ private:
// Find the named parameter and its typed parent parameter object (params) and return both.
// If the parameter is not found return the parent parameter object that the parameter should be added to.
// Return [nullptr, nullptr] if no matching parent parameter object exists.
std::tuple<Expr, Ptr<Parameters>> findParams(const std::string& name,
Type elementType,
// Return [nullptr, nullptr] if no matching parent parameter object exists.
std::tuple<Expr, Ptr<Parameters>> findParams(const std::string& name,
Type elementType,
bool typeSpecified) const {
Expr p; Ptr<Parameters> params;
if(typeSpecified) { // type has been specified, so we are only allowed to look for a parameter with that type
@ -373,12 +373,12 @@ private:
} else { // type has not been specified, so we take any type as long as the name matches
for(auto kvParams : paramsByElementType_) {
p = kvParams.second->get(name);
if(p) { // p has been found, return with matching params object
params = kvParams.second;
break;
}
if(kvParams.first == elementType) // even if p has not been found, set the params object to be returned
params = kvParams.second;
}
@ -399,8 +399,8 @@ private:
Expr p; Ptr<Parameters> params; std::tie
(p, params) = findParams(name, elementType, typeSpecified);
if(!params) {
if(!params) {
params = New<Parameters>(elementType);
params->init(backend_);
paramsByElementType_.insert({elementType, params});
@ -632,13 +632,13 @@ public:
* Return the Parameters object related to the graph.
* The Parameters object holds the whole set of the parameter nodes.
*/
Ptr<Parameters>& params() {
Ptr<Parameters>& params() {
// There are no parameter objects, that's weird.
ABORT_IF(paramsByElementType_.empty(), "No parameter object has been created");
// Safeguard against accessing parameters from the outside with multiple parameter types, not yet supported
ABORT_IF(paramsByElementType_.size() > 1, "Calling of params() is currently not supported with multiple ({}) parameters", paramsByElementType_.size());
// Safeguard against accessing parameters from the outside with other than default parameter type, not yet supported
auto it = paramsByElementType_.find(defaultElementType_);
ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_);
@ -650,7 +650,7 @@ public:
* Return the Parameters object related to the graph by elementType.
* The Parameters object holds the whole set of the parameter nodes of the given type.
*/
Ptr<Parameters>& params(Type elementType) {
Ptr<Parameters>& params(Type elementType) {
auto it = paramsByElementType_.find(elementType);
ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_);
return it->second;
@ -661,8 +661,8 @@ public:
* The default value is used if some node type is not specified.
*/
void setDefaultElementType(Type defaultElementType) {
ABORT_IF(!paramsByElementType_.empty() && defaultElementType != defaultElementType_,
"Parameter objects already exist, cannot change default type from {} to {}",
ABORT_IF(!paramsByElementType_.empty() && defaultElementType != defaultElementType_,
"Parameter objects already exist, cannot change default type from {} to {}",
defaultElementType_, defaultElementType);
defaultElementType_ = defaultElementType;
}
@ -746,7 +746,7 @@ public:
// skip over special parameters starting with "special:"
if(pName.substr(0, 8) == "special:")
continue;
// if during loading the loaded type is of the same type class as the default element type, allow conversion;
// otherwise keep the loaded type. This is used when e.g. loading a float32 model as a float16 model as both
// have type class TypeClass::float_type.
@ -781,9 +781,9 @@ public:
LOG(info, "Memory mapping model at {}", ptr);
auto items = io::mmapItems(ptr);
// Deal with default parameter set object that might not be a mapped object.
// This gets assigned during ExpressionGraph::setDevice(...) and by default
// This gets assigned during ExpressionGraph::setDevice(...) and by default
// would contain allocated tensors. Here we replace it with a mmapped version.
auto it = paramsByElementType_.find(defaultElementType_);
if(it != paramsByElementType_.end()) {

View File

@ -27,12 +27,12 @@ Expr checkpoint(Expr a) {
return a;
}
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
LambdaNodeFunctor fwd, size_t hash) {
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
}
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash) {
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd, hash);
}
@ -436,7 +436,7 @@ Expr std(Expr a, int ax) {
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::rms);
}
Expr var(Expr a, int ax) {
Expr var(Expr a, int ax) {
if(a->shape()[ax] == 1) // nothing to reduce, var(a) = 0
return a - a;
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::meanSqr);
@ -575,8 +575,8 @@ Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float sc
return Expression<AffineNodeOp>(nodes, transA, transB, scale);
}
// This operation used to implement auto-tuning. We have removed it for now due to complexity, but plan to revisit it in the future.
// The last branch with auto-tuner is:
// This operation used to implement auto-tuning. We have removed it for now due to complexity, but plan to revisit it in the future.
// The last branch with auto-tuner is:
// youki/packed-model-pr-backup1031
// https://machinetranslation.visualstudio.com/Marian/_git/marian-dev?version=GByouki%2Fpacked-model-pr-backup1031
// SHA: 3456a7ed1d1608cfad74cd2c414e7e8fe141aa52
@ -660,8 +660,8 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
}
} else {
// Default GEMM
ABORT_IF(!isFloat(aElementType) || !isFloat(bElementType),
"GPU-based GEMM only supports float types, you have A: {} and B: {}",
ABORT_IF(!isFloat(aElementType) || !isFloat(bElementType),
"GPU-based GEMM only supports float types, you have A: {} and B: {}",
aElementType, bElementType);
return affineDefault(a, b, bias, transA, transB, scale);
}
@ -669,7 +669,7 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
auto graph = a->graph();
if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu)
return Expression<AffineWithReluNodeOp>(a, b, bias, transA, transB, scale);
else
@ -775,7 +775,7 @@ Expr unlikelihood(Expr logits, Expr indices) {
int dimBatch = logits->shape()[-2];
int dimTime = logits->shape()[-3];
// @TODO: fix this outside of this function in decoder.h etc.
// @TODO: fix this outside of this function in decoder.h etc.
auto indicesWithLayout = reshape(indices, {1, dimTime, dimBatch, 1});
// This is currently implemented with multiple ops, might be worth doing a special operation like for cross_entropy

View File

@ -70,9 +70,9 @@ public:
outputs.push_back(output2);
}
auto concated = concatenate(outputs, -1);
auto concatenated = concatenate(outputs, -1);
return concated;
return concatenated;
}
protected:

View File

@ -67,7 +67,7 @@ public:
return count_->val()->scalar<T>();
}
// @TODO: add a funtion for returning maybe ratio?
// @TODO: add a function for returning maybe ratio?
size_t size() const {
ABORT_IF(!count_, "Labels have not been defined");
@ -189,7 +189,7 @@ public:
*
* L = sum_i^N L_i + N/M sum_j^M L_j
*
* We set labels to N. When reporting L/N this is equvalient to sum of means.
* We set labels to N. When reporting L/N this is equivalent to sum of means.
* Compare to sum of means below where N is factored into the loss, but labels
* are set to 1.
*/

View File

@ -76,7 +76,7 @@ private:
float scale = sqrtf(2.0f / (dimVoc + dimEmb));
// @TODO: switch to new random generator back-end.
// This is rarly used however.
// This is rarely used however.
std::random_device rd;
std::mt19937 engine(rd());

View File

@ -8,7 +8,7 @@
namespace marian {
/**
* This file contains nearly all BERT-related code and adds BERT-funtionality
* This file contains nearly all BERT-related code and adds BERT-functionality
* on top of existing classes like TansformerEncoder and Classifier.
*/
@ -82,7 +82,7 @@ public:
// Initialize to sample random vocab id
randomWord_.reset(new std::uniform_int_distribution<WordIndex>(0, (WordIndex)vocab.size()));
// Intialize to sample random percentage
// Initialize to sample random percentage
randomPercent_.reset(new std::uniform_real_distribution<float>(0.f, 1.f));
auto& words = subBatch->data();

View File

@ -14,7 +14,7 @@ namespace marian {
* Can be used to train sequence classifiers like language detection, BERT-next-sentence-prediction etc.
* Already has support for multi-objective training.
*
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for deocder/classifier
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for decoder/classifier
* multi-objective training.
*/
class EncoderClassifierBase : public models::IModel {

View File

@ -220,7 +220,7 @@ Ptr<DecoderState> EncoderDecoder::stepAll(Ptr<ExpressionGraph> graph,
if(clearGraph)
clear(graph);
// Required first step, also intializes shortlist
// Required first step, also initializes shortlist
auto state = startState(graph, batch);
// Fill state with embeddings from batch (ground truth)

View File

@ -70,7 +70,7 @@ public:
// Hack for translating with length longer than trained embeddings
// We check if the embedding matrix "Wpos" already exist so we can
// check the number of positions in that loaded parameter.
// We then have to restict the maximum length to the maximum positon
// We then have to restrict the maximum length to the maximum positon
// and positions beyond this will be the maximum position.
Expr seenEmb = graph_->get("Wpos");
int numPos = seenEmb ? seenEmb->shape()[-2] : maxLength;

View File

@ -101,7 +101,7 @@ public:
std::string maxRankStr = std::to_string(MPIWrapper::numMPIProcesses() -1);
while (rankStr.size() < maxRankStr.size()) // pad so that logs across MPI processes line up nicely
rankStr.insert(rankStr.begin(), ' ');
switchtoMultinodeLogging(rankStr);
switchToMultinodeLogging(rankStr);
}
// log hostnames in order, and test
@ -261,7 +261,7 @@ void finalizeMPI(Ptr<IMPIWrapper>&& mpi) {
ABORT_IF(mpi == nullptr || mpi != s_mpi, "attempted to finalize an inconsistent MPI instance. This should not be possible.");
mpi = nullptr; // destruct caller's handle
ABORT_IF(s_mpiUseCount == 0, "finalize called too many times. This should not be possible.");
if (s_mpiUseCount == 1) { // last call finalizes MPI, i.e. tells MPI that we sucessfully completed computation
if (s_mpiUseCount == 1) { // last call finalizes MPI, i.e. tells MPI that we successfully completed computation
ABORT_IF(s_mpi.use_count() != 1, "dangling reference to MPI??"); // caller kept another shared_ptr to this instance
s_mpi->finalize(); // signal successful completion to MPI
s_mpi = nullptr; // release the singleton instance upon last finalization

View File

@ -13,7 +13,7 @@ namespace marian {
// With -Ofast enabled gcc will fail to identify NaN or Inf. Safeguard here.
static inline bool isFinite(float x) {
#ifdef __GNUC__
#ifdef __GNUC__
ABORT_IF(std::isfinite(0.f / 0.f), "NaN detection unreliable. Disable -Ofast compiler option.");
#endif
return std::isfinite(x);
@ -27,7 +27,7 @@ static inline bool isFinite(float x) {
// if one value is nonfinite propagate Nan into the reduction.
static inline void accNanOrNorm(float& lhs, float rhs) {
if(isFinite(lhs) && isFinite(rhs)) {
lhs = sqrtf(lhs * lhs + rhs * rhs);
lhs = sqrtf(lhs * lhs + rhs * rhs);
} else
lhs = std::numeric_limits<float>::quiet_NaN();
}
@ -42,20 +42,20 @@ static inline void accNanOrNorm(float& lhs, float rhs) {
class GraphGroup {
protected:
Ptr<Options> options_;
Ptr<ICommunicator> comm_; // [not null] communicator, e.g. NCCLCommunicator
Ptr<IMPIWrapper> mpi_; // [not null] all MPI-like communication goes through this (this is a dummy implementation if no MPI run)
std::vector<DeviceId> devices_; // [deviceIndex]
ShardingMode shardingMode_{ShardingMode::global}; // If local and multi-node training, shard only on local devices and do full sync (faster). If global shard across entire set of GPUs (more RAM).
ShardingMode shardingMode_{ShardingMode::global}; // If local and multi-node training, shard only on local devices and do full sync (faster). If global shard across entire set of GPUs (more RAM).
// common for all graph groups, individual graph groups decide how to fill them
std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex]
std::vector<Ptr<models::ICriterionFunction>> models_; // [deviceIndex]
std::vector<Ptr<OptimizerBase>> optimizerShards_; // [deviceIndex]
Ptr<Scheduler> scheduler_; // scheduler that keeps track of how much has been processed
bool finalized_{false}; // 'true' if training has completed (further updates are no longer allowed)
double typicalTrgBatchWords_{0}; // for dynamic batch sizing: typical batch size in words
bool mbRoundUp_{true}; // round up batches for more efficient training but can make batch size less stable, disable with --mini-batch-round-up=false
@ -100,16 +100,16 @@ public:
virtual void load();
virtual void save(bool isFinal = false);
private:
void load(const OptimizerBase::ScatterStateFunc& scatterFn);
void save(bool isFinal,
const OptimizerBase::GatherStateFunc& gatherOptimizerStateFn);
bool restoreFromCheckpoint(const std::string& modelFileName,
bool restoreFromCheckpoint(const std::string& modelFileName,
const OptimizerBase::ScatterStateFunc& scatterFn);
void saveCheckpoint(const std::string& modelFileName,
void saveCheckpoint(const std::string& modelFileName,
const OptimizerBase::GatherStateFunc& gatherFn);
public:
@ -128,11 +128,11 @@ public:
float executeAndCollectNorm(const std::function<float(size_t, size_t, size_t)>& task);
float computeNormalizationFactor(float gNorm, size_t updateTrgWords);
/**
* Determine maximal batch size that can fit into the given workspace
* so that reallocation does not happen. Rather adjust the batch size
* based on the stastistics collected here. Activated with
* based on the statistics collected here. Activated with
* `--mini-batch-fit`.
* In a multi-GPU scenario, the first GPU is used to determine the size.
* The actual allowed size is then determined by multiplying it with the
@ -151,4 +151,4 @@ public:
void updateAverageTrgBatchWords(size_t trgBatchWords);
};
} // namespace marian
} // namespace marian