mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Cherry picked cleaning/refeactoring patches (#905)
Cherry-picked updates from pull request #457 Co-authored-by: Mateusz Chudyk <mateuszchudyk@gmail.com>
This commit is contained in:
parent
71b5454b9e
commit
07c39c7d76
@ -113,10 +113,10 @@ std::string CLIWrapper::switchGroup(std::string name) {
|
||||
return name;
|
||||
}
|
||||
|
||||
void CLIWrapper::parse(int argc, char **argv) {
|
||||
void CLIWrapper::parse(int argc, char** argv) {
|
||||
try {
|
||||
app_->parse(argc, argv);
|
||||
} catch(const CLI::ParseError &e) {
|
||||
} catch(const CLI::ParseError& e) {
|
||||
exit(app_->exit(e));
|
||||
}
|
||||
|
||||
@ -182,6 +182,13 @@ void CLIWrapper::parseAliases() {
|
||||
}
|
||||
}
|
||||
|
||||
std::string CLIWrapper::keyName(const std::string& args) const {
|
||||
// re-use existing functions from CLI11 to keep option names consistent
|
||||
return std::get<1>(
|
||||
CLI::detail::get_names(CLI::detail::split_names(args))) // get long names only
|
||||
.front(); // get first long name
|
||||
}
|
||||
|
||||
void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg) {
|
||||
auto cmdOptions = getParsedOptionNames();
|
||||
// Keep track of unrecognized options from the provided config
|
||||
@ -276,7 +283,7 @@ std::vector<std::string> CLIWrapper::getOrderedOptionNames() const {
|
||||
for(auto const &it : options_)
|
||||
keys.push_back(it.first);
|
||||
// sort option names by creation index
|
||||
sort(keys.begin(), keys.end(), [this](const std::string &a, const std::string &b) {
|
||||
sort(keys.begin(), keys.end(), [this](const std::string& a, const std::string& b) {
|
||||
return options_.at(a).idx < options_.at(b).idx;
|
||||
});
|
||||
return keys;
|
||||
|
@ -44,7 +44,7 @@ struct CLIAliasTuple {
|
||||
class CLIFormatter : public CLI::Formatter {
|
||||
public:
|
||||
CLIFormatter(size_t columnWidth, size_t screenWidth);
|
||||
virtual std::string make_option_desc(const CLI::Option *) const override;
|
||||
virtual std::string make_option_desc(const CLI::Option*) const override;
|
||||
|
||||
private:
|
||||
size_t screenWidth_{0};
|
||||
@ -85,12 +85,7 @@ private:
|
||||
|
||||
// Extract option name from a comma-separated list of long and short options, e.g. 'help' from
|
||||
// '--help,-h'
|
||||
std::string keyName(const std::string &args) const {
|
||||
// re-use existing functions from CLI11 to keep option names consistent
|
||||
return std::get<1>(
|
||||
CLI::detail::get_names(CLI::detail::split_names(args))) // get long names only
|
||||
.front(); // get first long name
|
||||
}
|
||||
std::string keyName(const std::string &args) const;
|
||||
|
||||
// Get names of options passed via command-line
|
||||
std::unordered_set<std::string> getParsedOptionNames() const;
|
||||
@ -134,7 +129,7 @@ public:
|
||||
* @return Option object
|
||||
*/
|
||||
template <typename T>
|
||||
CLI::Option *add(const std::string &args, const std::string &help, T val) {
|
||||
CLI::Option* add(const std::string& args, const std::string& help, T val) {
|
||||
return addOption<T>(keyName(args),
|
||||
args,
|
||||
help,
|
||||
@ -159,7 +154,7 @@ public:
|
||||
* @TODO: require to always state the default value creating the parser as this will be clearer
|
||||
*/
|
||||
template <typename T>
|
||||
CLI::Option *add(const std::string &args, const std::string &help) {
|
||||
CLI::Option* add(const std::string& args, const std::string& help) {
|
||||
return addOption<T>(keyName(args),
|
||||
args,
|
||||
help,
|
||||
@ -206,7 +201,7 @@ public:
|
||||
std::string switchGroup(std::string name = "");
|
||||
|
||||
// Parse command-line arguments. Handles --help and --version options
|
||||
void parse(int argc, char **argv);
|
||||
void parse(int argc, char** argv);
|
||||
|
||||
/**
|
||||
* @brief Expand aliases based on arguments parsed with parse(int, char**)
|
||||
@ -240,11 +235,12 @@ public:
|
||||
std::string dumpConfig(bool skipUnmodified = false) const;
|
||||
|
||||
private:
|
||||
template <typename T,
|
||||
// options with numeric and string-like values
|
||||
CLI::enable_if_t<!CLI::is_bool<T>::value && !CLI::is_vector<T>::value,
|
||||
CLI::detail::enabler> = CLI::detail::dummy>
|
||||
CLI::Option *addOption(const std::string &key,
|
||||
template <typename T>
|
||||
using EnableIfNumbericOrString = CLI::enable_if_t<!CLI::is_bool<T>::value
|
||||
&& !CLI::is_vector<T>::value, CLI::detail::enabler>;
|
||||
|
||||
template <typename T, EnableIfNumbericOrString<T> = CLI::detail::dummy>
|
||||
CLI::Option* addOption(const std::string &key,
|
||||
const std::string &args,
|
||||
const std::string &help,
|
||||
T val,
|
||||
@ -261,7 +257,7 @@ private:
|
||||
CLI::callback_t fun = [this, key](CLI::results_t res) {
|
||||
options_[key].priority = cli::OptionPriority::CommandLine;
|
||||
// get variable associated with the option
|
||||
auto &var = options_[key].var->as<T>();
|
||||
auto& var = options_[key].var->as<T>();
|
||||
// store parser result in var
|
||||
auto ret = CLI::detail::lexical_cast(res[0], var);
|
||||
// update YAML entry
|
||||
@ -288,10 +284,11 @@ private:
|
||||
return options_[key].opt;
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
// options with vector values
|
||||
CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
|
||||
CLI::Option *addOption(const std::string &key,
|
||||
template <typename T>
|
||||
using EnableIfVector = CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler>;
|
||||
|
||||
template <typename T, EnableIfVector<T> = CLI::detail::dummy>
|
||||
CLI::Option* addOption(const std::string &key,
|
||||
const std::string &args,
|
||||
const std::string &help,
|
||||
T val,
|
||||
@ -308,7 +305,7 @@ private:
|
||||
CLI::callback_t fun = [this, key](CLI::results_t res) {
|
||||
options_[key].priority = cli::OptionPriority::CommandLine;
|
||||
// get vector variable associated with the option
|
||||
auto &vec = options_[key].var->as<T>();
|
||||
auto& vec = options_[key].var->as<T>();
|
||||
vec.clear();
|
||||
bool ret = true;
|
||||
// handle '[]' as an empty vector
|
||||
@ -316,7 +313,7 @@ private:
|
||||
ret = true;
|
||||
} else {
|
||||
// populate the vector with parser results
|
||||
for(const auto &a : res) {
|
||||
for(const auto& a : res) {
|
||||
vec.emplace_back();
|
||||
ret &= CLI::detail::lexical_cast(a, vec.back());
|
||||
}
|
||||
@ -345,10 +342,11 @@ private:
|
||||
return options_[key].opt;
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
// options with boolean values, called flags in CLI11
|
||||
CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
|
||||
CLI::Option *addOption(const std::string &key,
|
||||
template <typename T>
|
||||
using EnableIfBoolean = CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler>;
|
||||
|
||||
template <typename T, EnableIfBoolean<T> = CLI::detail::dummy>
|
||||
CLI::Option* addOption(const std::string &key,
|
||||
const std::string &args,
|
||||
const std::string &help,
|
||||
T val,
|
||||
|
@ -107,7 +107,7 @@ private:
|
||||
* @param mode change the set of available command-line options, e.g. training, translation, etc.
|
||||
* @param validate validate parsed options and abort on failure
|
||||
*
|
||||
* @return parsed otions
|
||||
* @return parsed options
|
||||
*/
|
||||
Ptr<Options> parseOptions(int argc,
|
||||
char** argv,
|
||||
|
@ -942,10 +942,10 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
|
||||
cli.add<std::string>("--ulr-query-vectors",
|
||||
"Path to file with universal sources embeddings from projection into universal space",
|
||||
"");
|
||||
// keys: EK in Fig2 : is the keys of the target embbedings projected to unified space (i.e. ENU in
|
||||
// keys: EK in Fig2 : is the keys of the target embeddings projected to unified space (i.e. ENU in
|
||||
// multi-lingual case)
|
||||
cli.add<std::string>("--ulr-keys-vectors",
|
||||
"Path to file with universal sources embeddings of traget keys from projection into universal space",
|
||||
"Path to file with universal sources embeddings of target keys from projection into universal space",
|
||||
"");
|
||||
cli.add<bool>("--ulr-trainable-transformation",
|
||||
"Make Query Transformation Matrix A trainable");
|
||||
|
@ -10,7 +10,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is neccessary, remove if no problems occur.
|
||||
#define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is necessary, remove if no problems occur.
|
||||
#define NodeOp(op) [=]() { op; }
|
||||
|
||||
// helper macro to disable optimization (gcc only)
|
||||
|
@ -136,7 +136,7 @@ static void setErrorHandlers() {
|
||||
|
||||
// modify the log pattern for the "general" logger to include the MPI rank
|
||||
// This is called upon initializing MPI. It is needed to associate error messages with ranks.
|
||||
void switchtoMultinodeLogging(std::string nodeIdStr) {
|
||||
void switchToMultinodeLogging(std::string nodeIdStr) {
|
||||
Logger log = spdlog::get("general");
|
||||
if(log)
|
||||
log->set_pattern(fmt::format("[%Y-%m-%d %T mpi:{}] %v", nodeIdStr));
|
||||
|
@ -178,4 +178,4 @@ void checkedLog(std::string logger, std::string level, Args... args) {
|
||||
}
|
||||
|
||||
void createLoggers(const marian::Config* options = nullptr);
|
||||
void switchtoMultinodeLogging(std::string nodeIdStr);
|
||||
void switchToMultinodeLogging(std::string nodeIdStr);
|
||||
|
@ -98,7 +98,7 @@ public:
|
||||
* @brief Splice options from a YAML node
|
||||
*
|
||||
* By default, only options with keys that do not already exist in options_ are extracted from
|
||||
* node. These options are cloned if overwirte is true.
|
||||
* node. These options are cloned if overwrite is true.
|
||||
*
|
||||
* @param node a YAML node to transfer the options from
|
||||
* @param overwrite overwrite all options
|
||||
|
@ -379,7 +379,7 @@ public:
|
||||
* @see marian::data::SubBatch::split(size_t n)
|
||||
*/
|
||||
std::vector<Ptr<Batch>> split(size_t n, size_t sizeLimit /*=SIZE_MAX*/) override {
|
||||
ABORT_IF(size() == 0, "Encoutered batch size of 0");
|
||||
ABORT_IF(size() == 0, "Encountered batch size of 0");
|
||||
|
||||
std::vector<std::vector<Ptr<SubBatch>>> subs; // [subBatchIndex][streamIndex]
|
||||
// split each stream separately
|
||||
|
@ -44,7 +44,7 @@ public:
|
||||
virtual void prepare() {}
|
||||
virtual void restore(Ptr<TrainingState>) {}
|
||||
|
||||
// @TODO: remove after cleaning traininig/training.h
|
||||
// @TODO: remove after cleaning training/training.h
|
||||
virtual Ptr<Options> options() { return options_; }
|
||||
};
|
||||
|
||||
|
@ -10,7 +10,7 @@ namespace marian {
|
||||
class IVocab;
|
||||
|
||||
// Wrapper around vocabulary types. Can choose underlying
|
||||
// vocabulary implementation (vImpl_) based on speficied path
|
||||
// vocabulary implementation (vImpl_) based on specified path
|
||||
// and suffix.
|
||||
// Vocabulary implementations can currently be:
|
||||
// * DefaultVocabulary for YAML (*.yml and *.yaml) and TXT (any other non-specific ending)
|
||||
|
@ -70,9 +70,9 @@ public:
|
||||
outputs.push_back(output2);
|
||||
}
|
||||
|
||||
auto concated = concatenate(outputs, -1);
|
||||
auto concatenated = concatenate(outputs, -1);
|
||||
|
||||
return concated;
|
||||
return concatenated;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -67,7 +67,7 @@ public:
|
||||
return count_->val()->scalar<T>();
|
||||
}
|
||||
|
||||
// @TODO: add a funtion for returning maybe ratio?
|
||||
// @TODO: add a function for returning maybe ratio?
|
||||
|
||||
size_t size() const {
|
||||
ABORT_IF(!count_, "Labels have not been defined");
|
||||
@ -189,7 +189,7 @@ public:
|
||||
*
|
||||
* L = sum_i^N L_i + N/M sum_j^M L_j
|
||||
*
|
||||
* We set labels to N. When reporting L/N this is equvalient to sum of means.
|
||||
* We set labels to N. When reporting L/N this is equivalent to sum of means.
|
||||
* Compare to sum of means below where N is factored into the loss, but labels
|
||||
* are set to 1.
|
||||
*/
|
||||
|
@ -76,7 +76,7 @@ private:
|
||||
float scale = sqrtf(2.0f / (dimVoc + dimEmb));
|
||||
|
||||
// @TODO: switch to new random generator back-end.
|
||||
// This is rarly used however.
|
||||
// This is rarely used however.
|
||||
std::random_device rd;
|
||||
std::mt19937 engine(rd());
|
||||
|
||||
|
@ -8,7 +8,7 @@
|
||||
namespace marian {
|
||||
|
||||
/**
|
||||
* This file contains nearly all BERT-related code and adds BERT-funtionality
|
||||
* This file contains nearly all BERT-related code and adds BERT-functionality
|
||||
* on top of existing classes like TransformerEncoder and Classifier.
|
||||
*/
|
||||
|
||||
@ -82,7 +82,7 @@ public:
|
||||
// Initialize to sample random vocab id
|
||||
randomWord_.reset(new std::uniform_int_distribution<WordIndex>(0, (WordIndex)vocab.size()));
|
||||
|
||||
// Intialize to sample random percentage
|
||||
// Initialize to sample random percentage
|
||||
randomPercent_.reset(new std::uniform_real_distribution<float>(0.f, 1.f));
|
||||
|
||||
auto& words = subBatch->data();
|
||||
|
@ -14,7 +14,7 @@ namespace marian {
|
||||
* Can be used to train sequence classifiers like language detection, BERT-next-sentence-prediction etc.
|
||||
* Already has support for multi-objective training.
|
||||
*
|
||||
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for deocder/classifier
|
||||
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for decoder/classifier
|
||||
* multi-objective training.
|
||||
*/
|
||||
class EncoderClassifierBase : public models::IModel {
|
||||
|
@ -220,7 +220,7 @@ Ptr<DecoderState> EncoderDecoder::stepAll(Ptr<ExpressionGraph> graph,
|
||||
if(clearGraph)
|
||||
clear(graph);
|
||||
|
||||
// Required first step, also intializes shortlist
|
||||
// Required first step, also initializes shortlist
|
||||
auto state = startState(graph, batch);
|
||||
|
||||
// Fill state with embeddings from batch (ground truth)
|
||||
|
@ -70,7 +70,7 @@ public:
|
||||
// Hack for translating with length longer than trained embeddings
|
||||
// We check if the embedding matrix "Wpos" already exist so we can
|
||||
// check the number of positions in that loaded parameter.
|
||||
// We then have to restict the maximum length to the maximum positon
|
||||
// We then have to restrict the maximum length to the maximum position
|
||||
// and positions beyond this will be the maximum position.
|
||||
Expr seenEmb = graph_->get("Wpos");
|
||||
int numPos = seenEmb ? seenEmb->shape()[-2] : maxLength;
|
||||
|
@ -101,7 +101,7 @@ public:
|
||||
std::string maxRankStr = std::to_string(MPIWrapper::numMPIProcesses() -1);
|
||||
while (rankStr.size() < maxRankStr.size()) // pad so that logs across MPI processes line up nicely
|
||||
rankStr.insert(rankStr.begin(), ' ');
|
||||
switchtoMultinodeLogging(rankStr);
|
||||
switchToMultinodeLogging(rankStr);
|
||||
}
|
||||
|
||||
// log hostnames in order, and test
|
||||
@ -261,7 +261,7 @@ void finalizeMPI(Ptr<IMPIWrapper>&& mpi) {
|
||||
ABORT_IF(mpi == nullptr || mpi != s_mpi, "attempted to finalize an inconsistent MPI instance. This should not be possible.");
|
||||
mpi = nullptr; // destruct caller's handle
|
||||
ABORT_IF(s_mpiUseCount == 0, "finalize called too many times. This should not be possible.");
|
||||
if (s_mpiUseCount == 1) { // last call finalizes MPI, i.e. tells MPI that we sucessfully completed computation
|
||||
if (s_mpiUseCount == 1) { // last call finalizes MPI, i.e. tells MPI that we successfully completed computation
|
||||
ABORT_IF(s_mpi.use_count() != 1, "dangling reference to MPI??"); // caller kept another shared_ptr to this instance
|
||||
s_mpi->finalize(); // signal successful completion to MPI
|
||||
s_mpi = nullptr; // release the singleton instance upon last finalization
|
||||
|
@ -132,7 +132,7 @@ public:
|
||||
/**
|
||||
* Determine maximal batch size that can fit into the given workspace
|
||||
* so that reallocation does not happen. Rather adjust the batch size
|
||||
* based on the stastistics collected here. Activated with
|
||||
* based on the statistics collected here. Activated with
|
||||
* `--mini-batch-fit`.
|
||||
* In a multi-GPU scenario, the first GPU is used to determine the size.
|
||||
* The actual allowed size is then determined by multiplying it with the
|
||||
|
Loading…
Reference in New Issue
Block a user