Cherry picked cleaning/refactoring patches (#905)

Cherry-picked updates from pull request #457

Co-authored-by: Mateusz Chudyk <mateuszchudyk@gmail.com>
This commit is contained in:
Roman Grundkiewicz 2022-01-28 14:16:41 +00:00 committed by GitHub
parent 71b5454b9e
commit 07c39c7d76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 113 additions and 108 deletions

View File

@ -113,10 +113,10 @@ std::string CLIWrapper::switchGroup(std::string name) {
return name;
}
void CLIWrapper::parse(int argc, char **argv) {
void CLIWrapper::parse(int argc, char** argv) {
try {
app_->parse(argc, argv);
} catch(const CLI::ParseError &e) {
} catch(const CLI::ParseError& e) {
exit(app_->exit(e));
}
@ -182,6 +182,13 @@ void CLIWrapper::parseAliases() {
}
}
// Extract the option key from a comma-separated list of long and short option names,
// e.g. 'help' from '--help,-h'. Returns the first long name found in args.
// NOTE(review): front() is called on the list of long names without an emptiness check;
// presumably every args string contains at least one long name — confirm against callers.
std::string CLIWrapper::keyName(const std::string& args) const {
// re-use existing functions from CLI11 to keep option names consistent
return std::get<1>(
CLI::detail::get_names(CLI::detail::split_names(args))) // get long names only
.front(); // get first long name
}
void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg) {
auto cmdOptions = getParsedOptionNames();
// Keep track of unrecognized options from the provided config
@ -276,7 +283,7 @@ std::vector<std::string> CLIWrapper::getOrderedOptionNames() const {
for(auto const &it : options_)
keys.push_back(it.first);
// sort option names by creation index
sort(keys.begin(), keys.end(), [this](const std::string &a, const std::string &b) {
sort(keys.begin(), keys.end(), [this](const std::string& a, const std::string& b) {
return options_.at(a).idx < options_.at(b).idx;
});
return keys;

View File

@ -44,7 +44,7 @@ struct CLIAliasTuple {
class CLIFormatter : public CLI::Formatter {
public:
CLIFormatter(size_t columnWidth, size_t screenWidth);
virtual std::string make_option_desc(const CLI::Option *) const override;
virtual std::string make_option_desc(const CLI::Option*) const override;
private:
size_t screenWidth_{0};
@ -85,12 +85,7 @@ private:
// Extract option name from a comma-separated list of long and short options, e.g. 'help' from
// '--help,-h'
std::string keyName(const std::string &args) const {
// re-use existing functions from CLI11 to keep option names consistent
return std::get<1>(
CLI::detail::get_names(CLI::detail::split_names(args))) // get long names only
.front(); // get first long name
}
std::string keyName(const std::string &args) const;
// Get names of options passed via command-line
std::unordered_set<std::string> getParsedOptionNames() const;
@ -134,7 +129,7 @@ public:
* @return Option object
*/
template <typename T>
CLI::Option *add(const std::string &args, const std::string &help, T val) {
CLI::Option* add(const std::string& args, const std::string& help, T val) {
return addOption<T>(keyName(args),
args,
help,
@ -159,7 +154,7 @@ public:
* @TODO: require to always state the default value creating the parser as this will be clearer
*/
template <typename T>
CLI::Option *add(const std::string &args, const std::string &help) {
CLI::Option* add(const std::string& args, const std::string& help) {
return addOption<T>(keyName(args),
args,
help,
@ -206,7 +201,7 @@ public:
std::string switchGroup(std::string name = "");
// Parse command-line arguments. Handles --help and --version options
void parse(int argc, char **argv);
void parse(int argc, char** argv);
/**
* @brief Expand aliases based on arguments parsed with parse(int, char**)
@ -240,11 +235,12 @@ public:
std::string dumpConfig(bool skipUnmodified = false) const;
private:
template <typename T,
// options with numeric and string-like values
CLI::enable_if_t<!CLI::is_bool<T>::value && !CLI::is_vector<T>::value,
CLI::detail::enabler> = CLI::detail::dummy>
CLI::Option *addOption(const std::string &key,
template <typename T>
using EnableIfNumbericOrString = CLI::enable_if_t<!CLI::is_bool<T>::value
&& !CLI::is_vector<T>::value, CLI::detail::enabler>;
template <typename T, EnableIfNumbericOrString<T> = CLI::detail::dummy>
CLI::Option* addOption(const std::string &key,
const std::string &args,
const std::string &help,
T val,
@ -261,7 +257,7 @@ private:
CLI::callback_t fun = [this, key](CLI::results_t res) {
options_[key].priority = cli::OptionPriority::CommandLine;
// get variable associated with the option
auto &var = options_[key].var->as<T>();
auto& var = options_[key].var->as<T>();
// store parser result in var
auto ret = CLI::detail::lexical_cast(res[0], var);
// update YAML entry
@ -288,10 +284,11 @@ private:
return options_[key].opt;
}
template <typename T,
// options with vector values
CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
CLI::Option *addOption(const std::string &key,
template <typename T>
using EnableIfVector = CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler>;
template <typename T, EnableIfVector<T> = CLI::detail::dummy>
CLI::Option* addOption(const std::string &key,
const std::string &args,
const std::string &help,
T val,
@ -308,7 +305,7 @@ private:
CLI::callback_t fun = [this, key](CLI::results_t res) {
options_[key].priority = cli::OptionPriority::CommandLine;
// get vector variable associated with the option
auto &vec = options_[key].var->as<T>();
auto& vec = options_[key].var->as<T>();
vec.clear();
bool ret = true;
// handle '[]' as an empty vector
@ -316,7 +313,7 @@ private:
ret = true;
} else {
// populate the vector with parser results
for(const auto &a : res) {
for(const auto& a : res) {
vec.emplace_back();
ret &= CLI::detail::lexical_cast(a, vec.back());
}
@ -345,10 +342,11 @@ private:
return options_[key].opt;
}
template <typename T,
// options with boolean values, called flags in CLI11
CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
CLI::Option *addOption(const std::string &key,
template <typename T>
using EnableIfBoolean = CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler>;
template <typename T, EnableIfBoolean<T> = CLI::detail::dummy>
CLI::Option* addOption(const std::string &key,
const std::string &args,
const std::string &help,
T val,

View File

@ -107,7 +107,7 @@ private:
* @param mode change the set of available command-line options, e.g. training, translation, etc.
* @param validate validate parsed options and abort on failure
*
* @return parsed otions
* @return parsed options
*/
Ptr<Options> parseOptions(int argc,
char** argv,

View File

@ -942,10 +942,10 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
cli.add<std::string>("--ulr-query-vectors",
"Path to file with universal sources embeddings from projection into universal space",
"");
// keys: EK in Fig2 : is the keys of the target embbedings projected to unified space (i.e. ENU in
// keys: EK in Fig2 : is the keys of the target embeddings projected to unified space (i.e. ENU in
// multi-lingual case)
cli.add<std::string>("--ulr-keys-vectors",
"Path to file with universal sources embeddings of traget keys from projection into universal space",
"Path to file with universal sources embeddings of target keys from projection into universal space",
"");
cli.add<bool>("--ulr-trainable-transformation",
"Make Query Transformation Matrix A trainable");

View File

@ -10,7 +10,7 @@
#include <string>
#include <vector>
#define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is neccessary, remove if no problems occur.
#define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is necessary, remove if no problems occur.
#define NodeOp(op) [=]() { op; }
// helper macro to disable optimization (gcc only)

View File

@ -136,7 +136,7 @@ static void setErrorHandlers() {
// modify the log pattern for the "general" logger to include the MPI rank
// This is called upon initializing MPI. It is needed to associate error messages with ranks.
void switchtoMultinodeLogging(std::string nodeIdStr) {
void switchToMultinodeLogging(std::string nodeIdStr) {
Logger log = spdlog::get("general");
if(log)
log->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] %v", nodeIdStr));

View File

@ -178,4 +178,4 @@ void checkedLog(std::string logger, std::string level, Args... args) {
}
void createLoggers(const marian::Config* options = nullptr);
void switchtoMultinodeLogging(std::string nodeIdStr);
void switchToMultinodeLogging(std::string nodeIdStr);

View File

@ -98,7 +98,7 @@ public:
* @brief Splice options from a YAML node
*
* By default, only options with keys that do not already exist in options_ are extracted from
* node. These options are cloned if overwirte is true.
* node. These options are cloned if overwrite is true.
*
* @param node a YAML node to transfer the options from
* @param overwrite overwrite all options

View File

@ -379,7 +379,7 @@ public:
* @see marian::data::SubBatch::split(size_t n)
*/
std::vector<Ptr<Batch>> split(size_t n, size_t sizeLimit /*=SIZE_MAX*/) override {
ABORT_IF(size() == 0, "Encoutered batch size of 0");
ABORT_IF(size() == 0, "Encountered batch size of 0");
std::vector<std::vector<Ptr<SubBatch>>> subs; // [subBatchIndex][streamIndex]
// split each stream separately

View File

@ -44,7 +44,7 @@ public:
virtual void prepare() {}
virtual void restore(Ptr<TrainingState>) {}
// @TODO: remove after cleaning traininig/training.h
// @TODO: remove after cleaning training/training.h
virtual Ptr<Options> options() { return options_; }
};

View File

@ -10,7 +10,7 @@ namespace marian {
class IVocab;
// Wrapper around vocabulary types. Can choose underlying
// vocabulary implementation (vImpl_) based on speficied path
// vocabulary implementation (vImpl_) based on specified path
// and suffix.
// Vocabulary implementations can currently be:
// * DefaultVocabulary for YAML (*.yml and *.yaml) and TXT (any other non-specific ending)

View File

@ -70,9 +70,9 @@ public:
outputs.push_back(output2);
}
auto concated = concatenate(outputs, -1);
auto concatenated = concatenate(outputs, -1);
return concated;
return concatenated;
}
protected:

View File

@ -67,7 +67,7 @@ public:
return count_->val()->scalar<T>();
}
// @TODO: add a funtion for returning maybe ratio?
// @TODO: add a function for returning maybe ratio?
size_t size() const {
ABORT_IF(!count_, "Labels have not been defined");
@ -189,7 +189,7 @@ public:
*
* L = sum_i^N L_i + N/M sum_j^M L_j
*
* We set labels to N. When reporting L/N this is equvalient to sum of means.
* We set labels to N. When reporting L/N this is equivalent to sum of means.
* Compare to sum of means below where N is factored into the loss, but labels
* are set to 1.
*/

View File

@ -76,7 +76,7 @@ private:
float scale = sqrtf(2.0f / (dimVoc + dimEmb));
// @TODO: switch to new random generator back-end.
// This is rarly used however.
// This is rarely used however.
std::random_device rd;
std::mt19937 engine(rd());

View File

@ -8,7 +8,7 @@
namespace marian {
/**
* This file contains nearly all BERT-related code and adds BERT-funtionality
* This file contains nearly all BERT-related code and adds BERT-functionality
* on top of existing classes like TransformerEncoder and Classifier.
*/
@ -82,7 +82,7 @@ public:
// Initialize to sample random vocab id
randomWord_.reset(new std::uniform_int_distribution<WordIndex>(0, (WordIndex)vocab.size()));
// Intialize to sample random percentage
// Initialize to sample random percentage
randomPercent_.reset(new std::uniform_real_distribution<float>(0.f, 1.f));
auto& words = subBatch->data();

View File

@ -14,7 +14,7 @@ namespace marian {
* Can be used to train sequence classifiers like language detection, BERT-next-sentence-prediction etc.
* Already has support for multi-objective training.
*
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for deocder/classifier
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for decoder/classifier
* multi-objective training.
*/
class EncoderClassifierBase : public models::IModel {

View File

@ -220,7 +220,7 @@ Ptr<DecoderState> EncoderDecoder::stepAll(Ptr<ExpressionGraph> graph,
if(clearGraph)
clear(graph);
// Required first step, also intializes shortlist
// Required first step, also initializes shortlist
auto state = startState(graph, batch);
// Fill state with embeddings from batch (ground truth)

View File

@ -70,7 +70,7 @@ public:
// Hack for translating with length longer than trained embeddings
// We check if the embedding matrix "Wpos" already exist so we can
// check the number of positions in that loaded parameter.
// We then have to restict the maximum length to the maximum positon
// We then have to restrict the maximum length to the maximum position
// and positions beyond this will be the maximum position.
Expr seenEmb = graph_->get("Wpos");
int numPos = seenEmb ? seenEmb->shape()[-2] : maxLength;

View File

@ -101,7 +101,7 @@ public:
std::string maxRankStr = std::to_string(MPIWrapper::numMPIProcesses() -1);
while (rankStr.size() < maxRankStr.size()) // pad so that logs across MPI processes line up nicely
rankStr.insert(rankStr.begin(), ' ');
switchtoMultinodeLogging(rankStr);
switchToMultinodeLogging(rankStr);
}
// log hostnames in order, and test
@ -261,7 +261,7 @@ void finalizeMPI(Ptr<IMPIWrapper>&& mpi) {
ABORT_IF(mpi == nullptr || mpi != s_mpi, "attempted to finalize an inconsistent MPI instance. This should not be possible.");
mpi = nullptr; // destruct caller's handle
ABORT_IF(s_mpiUseCount == 0, "finalize called too many times. This should not be possible.");
if (s_mpiUseCount == 1) { // last call finalizes MPI, i.e. tells MPI that we sucessfully completed computation
if (s_mpiUseCount == 1) { // last call finalizes MPI, i.e. tells MPI that we successfully completed computation
ABORT_IF(s_mpi.use_count() != 1, "dangling reference to MPI??"); // caller kept another shared_ptr to this instance
s_mpi->finalize(); // signal successful completion to MPI
s_mpi = nullptr; // release the singleton instance upon last finalization

View File

@ -132,7 +132,7 @@ public:
/**
* Determine maximal batch size that can fit into the given workspace
* so that reallocation does not happen. Rather adjust the batch size
* based on the stastistics collected here. Activated with
* based on the statistics collected here. Activated with
* `--mini-batch-fit`.
* In a multi-GPU scenario, the first GPU is used to determine the size.
* The actual allowed size is then determined by multiplying it with the