mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
merge with public master
This commit is contained in:
commit
1174cecbd6
@ -224,7 +224,14 @@ if(CUDA_FOUND)
|
||||
endif()
|
||||
|
||||
else(CUDA_FOUND)
|
||||
message(FATAL_ERROR "CUDA has not been found, set -DCOMPILE_CUDA=off to avoid this check and to compile the CPU version only")
|
||||
message("
|
||||
Cannot find suitable CUDA libraries. Specify the path explicitly with
|
||||
-DCUDA_TOOLKIT_ROOT_DIR=/path/to/appropriate/cuda/installation
|
||||
(hint: try /usr/local/$(readlink /usr/local/cuda))
|
||||
OR compile the CPU-only version of Marian with
|
||||
-DCOMPILE_CUDA=off
|
||||
")
|
||||
message(FATAL_ERROR "FATAL ERROR: No suitable CUDA library found.")
|
||||
endif(CUDA_FOUND)
|
||||
|
||||
else(COMPILE_CUDA)
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit 08c1485c2c925e4e1ed8a3e5df686ed5c7dc496e
|
||||
Subproject commit 95b66e74d47107a2e3abad5a4c5338904129a25a
|
@ -17,6 +17,7 @@ add_library(marian STATIC
|
||||
common/config_parser.cpp
|
||||
common/aliases.cpp
|
||||
common/config_validator.cpp
|
||||
common/options.cpp
|
||||
common/binary.cpp
|
||||
common/io.cpp
|
||||
common/filesystem.cpp
|
||||
|
@ -110,8 +110,12 @@ CLIWrapper::CLIWrapper(Ptr<marian::Options> options,
|
||||
|
||||
CLIWrapper::~CLIWrapper() {}
|
||||
|
||||
void CLIWrapper::switchGroup(const std::string &name) {
|
||||
currentGroup_ = name.empty() ? defaultGroup_ : name;
|
||||
// set current group to name, return previous group
|
||||
std::string CLIWrapper::switchGroup(std::string name) {
|
||||
currentGroup_.swap(name);
|
||||
if (currentGroup_.empty())
|
||||
currentGroup_ = defaultGroup_;
|
||||
return name;
|
||||
}
|
||||
|
||||
void CLIWrapper::parse(int argc, char **argv) {
|
||||
|
@ -212,8 +212,9 @@ public:
|
||||
* Switch to different option group or to the default group if argument is empty.
|
||||
*
|
||||
* @param name Header of the option group
|
||||
* @return Previous group.
|
||||
*/
|
||||
void switchGroup(const std::string &name = "");
|
||||
std::string switchGroup(std::string name = "");
|
||||
|
||||
// Parse command-line arguments. Handles --help and --version options
|
||||
void parse(int argc, char **argv);
|
||||
|
@ -1,9 +1,11 @@
|
||||
#include "common/config.h"
|
||||
#include "common/config_parser.h"
|
||||
#include "common/file_stream.h"
|
||||
#include "common/logging.h"
|
||||
#include "common/options.h"
|
||||
#include "common/regex.h"
|
||||
#include "common/utils.h"
|
||||
#include "common/version.h"
|
||||
#include "common/regex.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
@ -14,35 +16,26 @@ namespace marian {
|
||||
// @TODO: keep seed in a single place, now it is kept here and in Config/Options
|
||||
size_t Config::seed = (size_t)time(0);
|
||||
|
||||
Config::Config(int argc,
|
||||
char** argv,
|
||||
cli::mode mode,
|
||||
bool validate /*= true*/) {
|
||||
initialize(argc, argv, mode, validate);
|
||||
|
||||
Config::Config(ConfigParser const& cp) {
|
||||
initialize(cp);
|
||||
}
|
||||
|
||||
Config::Config(int argc, char** argv, cli::mode mode, bool validate /*= true*/)
|
||||
: Config(ConfigParser(argc, argv, mode, validate)) {}
|
||||
|
||||
Config::Config(const Config& other) : config_(YAML::Clone(other.config_)) {}
|
||||
Config::Config(const Options& options) : config_(YAML::Clone(options.getYaml())) {}
|
||||
|
||||
void Config::initialize(int argc, char** argv, cli::mode mode, bool validate) {
|
||||
auto parser = ConfigParser(argc, argv, mode, validate);
|
||||
config_ = parser.getConfig();
|
||||
void Config::initialize(ConfigParser const& cp) {
|
||||
config_ = YAML::Clone(cp.getConfig());
|
||||
cli::mode mode = cp.getMode();
|
||||
|
||||
createLoggers(this);
|
||||
|
||||
// echo version and command line
|
||||
LOG(info, "[marian] Marian {}", buildVersion());
|
||||
std::string cmdLine;
|
||||
for(int i = 0; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
std::string quote; // attempt to quote special chars
|
||||
if(arg.empty() || arg.find_first_of(" #`\"'\\${}|&^?*!()%><") != std::string::npos)
|
||||
quote = "'";
|
||||
arg = regex::regex_replace(arg, regex::regex("'"), "'\\''");
|
||||
if(!cmdLine.empty())
|
||||
cmdLine.push_back(' ');
|
||||
cmdLine += quote + arg + quote;
|
||||
}
|
||||
std::string cmdLine = cp.cmdLine();
|
||||
std::string hostname; int pid; std::tie
|
||||
(hostname, pid) = utils::hostnameAndProcessId();
|
||||
LOG(info, "[marian] Running on {} as process {} with command line:", hostname, pid);
|
||||
@ -262,14 +255,17 @@ std::vector<DeviceId> Config::getDevices(Ptr<Options> options,
|
||||
return devices;
|
||||
}
|
||||
|
||||
Ptr<Options> parseOptions(int argc,
|
||||
char** argv,
|
||||
cli::mode mode,
|
||||
bool validate /*= true*/) {
|
||||
auto config = New<Config>(argc, argv, mode, validate);
|
||||
auto options = New<Options>();
|
||||
options->merge(config->get());
|
||||
return options;
|
||||
Ptr<Options>
|
||||
parseOptions(int argc, char** argv, cli::mode mode, bool validate){
|
||||
ConfigParser cp(mode);
|
||||
return cp.parseOptions(argc, argv, validate);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Config& config) {
|
||||
YAML::Emitter outYaml;
|
||||
cli::OutputYaml(config.get(), outYaml);
|
||||
out << outYaml.c_str();
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace marian
|
||||
|
@ -38,6 +38,7 @@ public:
|
||||
|
||||
typedef YAML::Node YamlNode;
|
||||
|
||||
Config(ConfigParser const& cp);
|
||||
// TODO: remove mode from this class
|
||||
Config(int argc,
|
||||
char** argv,
|
||||
@ -47,7 +48,7 @@ public:
|
||||
Config(const Config& other);
|
||||
Config(const Options& options);
|
||||
|
||||
void initialize(int argc, char** argv, cli::mode mode, bool validate);
|
||||
void initialize(ConfigParser const& cp);
|
||||
|
||||
bool has(const std::string& key) const;
|
||||
|
||||
@ -83,12 +84,7 @@ public:
|
||||
|
||||
void save(const std::string& name);
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, const Config& config) {
|
||||
YAML::Emitter outYaml;
|
||||
cli::OutputYaml(config.get(), outYaml);
|
||||
out << outYaml.c_str();
|
||||
return out;
|
||||
}
|
||||
friend std::ostream& operator<<(std::ostream& out, const Config& config);
|
||||
|
||||
static std::vector<DeviceId> getDevices(Ptr<Options> options,
|
||||
size_t myMPIRank = 0,
|
||||
|
@ -1,12 +1,13 @@
|
||||
#include "common/config_parser.h"
|
||||
|
||||
#include "common/definitions.h"
|
||||
#include "common/cli_helper.h"
|
||||
#include "common/config.h"
|
||||
#include "common/config_parser.h"
|
||||
#include "common/config_validator.h"
|
||||
#include "common/definitions.h"
|
||||
#include "common/file_stream.h"
|
||||
#include "common/logging.h"
|
||||
#include "common/options.h"
|
||||
#include "common/regex.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
@ -49,6 +50,56 @@ const std::set<std::string> PATHS = {
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
std::string escapeCmdLine(int argc, char** argv){
|
||||
std::string cmdLine;
|
||||
for(int i = 0; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
std::string quote; // attempt to quote special chars
|
||||
if(arg.empty() || arg.find_first_of(" #`\"'\\${}|&^?*!()%><") != std::string::npos)
|
||||
quote = "'";
|
||||
arg = regex::regex_replace(arg, regex::regex("'"), "'\\''");
|
||||
if(!cmdLine.empty())
|
||||
cmdLine.push_back(' ');
|
||||
cmdLine += quote + arg + quote;
|
||||
}
|
||||
return cmdLine;
|
||||
}
|
||||
|
||||
std::string const& ConfigParser::cmdLine() const {
|
||||
return cmdLine_;
|
||||
}
|
||||
|
||||
ConfigParser::ConfigParser(cli::mode mode)
|
||||
: cli_(config_,"Marian: Fast Neural Machine Translation in C++",
|
||||
"General options", "", 40),
|
||||
mode_(mode == cli::mode::server ? cli::mode::translation : mode) {
|
||||
|
||||
addOptionsGeneral(cli_);
|
||||
if (mode == cli::mode::server)
|
||||
addOptionsServer(cli_);
|
||||
addOptionsModel(cli_);
|
||||
|
||||
// clang-format off
|
||||
switch(mode_) {
|
||||
case cli::mode::training:
|
||||
addOptionsTraining(cli_);
|
||||
addOptionsValidation(cli_);
|
||||
break;
|
||||
case cli::mode::translation:
|
||||
addOptionsTranslation(cli_);
|
||||
break;
|
||||
case cli::mode::scoring:
|
||||
addOptionsScoring(cli_);
|
||||
break;
|
||||
default:
|
||||
ABORT("wrong CLI mode");
|
||||
break;
|
||||
}
|
||||
|
||||
addAliases(cli_);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
|
||||
int defaultWorkspace = (mode_ == cli::mode::translation) ? 512 : 2048;
|
||||
|
||||
@ -87,14 +138,16 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
|
||||
|
||||
void ConfigParser::addOptionsServer(cli::CLIWrapper& cli) {
|
||||
// clang-format off
|
||||
auto previous_group = cli.switchGroup("Server options");
|
||||
cli.add<size_t>("--port,-p",
|
||||
"Port number for web socket server",
|
||||
8080);
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
||||
cli.switchGroup("Model options");
|
||||
auto previous_group = cli.switchGroup("Model options");
|
||||
|
||||
// clang-format off
|
||||
if(mode_ == cli::mode::translation) {
|
||||
@ -259,11 +312,12 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
||||
cli.add<float>("--transformer-dropout-ffn",
|
||||
"Dropout for transformer filter (0 = no dropout)");
|
||||
}
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
||||
cli.switchGroup("Training options");
|
||||
auto previous_group = cli.switchGroup("Training options");
|
||||
// clang-format off
|
||||
cli.add<std::string>("--cost-type", // @TODO: rename to loss-type
|
||||
"Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-mean");
|
||||
@ -447,11 +501,12 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
||||
|
||||
cli.add<std::vector<std::string>>("--task",
|
||||
"Use predefined set of options. Possible values: transformer, transformer-big");
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
||||
cli.switchGroup("Validation set options");
|
||||
auto previous_group = cli.switchGroup("Validation set options");
|
||||
|
||||
// clang-format off
|
||||
cli.add<std::vector<std::string>>("--valid-sets",
|
||||
@ -508,11 +563,12 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
||||
"Keep best model for each validation metric");
|
||||
cli.add<std::string>("--valid-log",
|
||||
"Log validation scores to file given by arg");
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
|
||||
cli.switchGroup("Translator options");
|
||||
auto previous_group = cli.switchGroup("Translator options");
|
||||
|
||||
// clang-format off
|
||||
cli.add<std::vector<std::string>>("--input,-i",
|
||||
@ -572,11 +628,12 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
|
||||
// add ULR settings
|
||||
addSuboptionsULR(cli);
|
||||
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
|
||||
cli.switchGroup("Scorer options");
|
||||
auto previous_group = cli.switchGroup("Scorer options");
|
||||
|
||||
// clang-format off
|
||||
cli.add<bool>("--no-reload",
|
||||
@ -610,15 +667,13 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
|
||||
|
||||
cli.add<bool>("--optimize",
|
||||
"Optimize speed aggressively sacrificing memory or precision");
|
||||
|
||||
//@TODO: gemm-type missing? fix this
|
||||
|
||||
cli.add<bool>("--fp16",
|
||||
"Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
|
||||
cli.add<std::vector<std::string>>("--precision",
|
||||
"Mixed precision for inference, set parameter type in expression graph",
|
||||
{"float32"});
|
||||
|
||||
cli.switchGroup(previous_group);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@ -747,46 +802,20 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
||||
cli::CLIWrapper cli(config_,
|
||||
"Marian: Fast Neural Machine Translation in C++",
|
||||
"General options",
|
||||
"",
|
||||
40);
|
||||
|
||||
addOptionsGeneral(cli);
|
||||
if(modeServer_)
|
||||
addOptionsServer(cli);
|
||||
addOptionsModel(cli);
|
||||
cli::mode ConfigParser::getMode() const { return mode_; }
|
||||
|
||||
// clang-format off
|
||||
switch(mode_) {
|
||||
case cli::mode::training:
|
||||
addOptionsTraining(cli);
|
||||
addOptionsValidation(cli);
|
||||
break;
|
||||
case cli::mode::translation:
|
||||
addOptionsTranslation(cli);
|
||||
break;
|
||||
case cli::mode::scoring:
|
||||
addOptionsScoring(cli);
|
||||
break;
|
||||
default:
|
||||
ABORT("wrong CLI mode");
|
||||
break;
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
addAliases(cli);
|
||||
Ptr<Options> ConfigParser::parseOptions(int argc, char** argv, bool doValidate){
|
||||
cmdLine_ = escapeCmdLine(argc,argv);
|
||||
|
||||
// parse command-line options and fill wrapped YAML config
|
||||
cli.parse(argc, argv);
|
||||
cli_.parse(argc, argv);
|
||||
|
||||
// get paths to extra config files
|
||||
auto configPaths = findConfigPaths();
|
||||
if(!configPaths.empty()) {
|
||||
auto config = loadConfigFiles(configPaths);
|
||||
cli.updateConfig(config,
|
||||
cli_.updateConfig(config,
|
||||
cli::OptionPriority::ConfigFile,
|
||||
"There are option(s) in a config file that are not expected");
|
||||
}
|
||||
@ -807,15 +836,18 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
|
||||
config_.remove("dump-config");
|
||||
|
||||
if(dumpMode == "expand") {
|
||||
cli.parseAliases();
|
||||
cli_.parseAliases();
|
||||
}
|
||||
|
||||
bool minimal = (dumpMode == "minimal" || dumpMode == "expand");
|
||||
std::cout << cli.dumpConfig(minimal) << std::endl;
|
||||
std::cout << cli_.dumpConfig(minimal) << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
cli.parseAliases();
|
||||
cli_.parseAliases();
|
||||
auto opts = New<Options>();
|
||||
opts->merge(Config(*this).get());
|
||||
return opts;
|
||||
}
|
||||
|
||||
std::vector<std::string> ConfigParser::findConfigPaths() {
|
||||
@ -877,7 +909,7 @@ YAML::Node ConfigParser::loadConfigFiles(const std::vector<std::string>& paths)
|
||||
return configAll;
|
||||
}
|
||||
|
||||
YAML::Node ConfigParser::getConfig() const {
|
||||
const YAML::Node& ConfigParser::getConfig() const {
|
||||
return config_;
|
||||
}
|
||||
} // namespace marian
|
||||
|
@ -14,7 +14,6 @@
|
||||
namespace marian {
|
||||
|
||||
namespace cli {
|
||||
// CLI mode
|
||||
enum struct mode { training, translation, scoring, server };
|
||||
} // namespace cli
|
||||
|
||||
@ -22,15 +21,65 @@ enum struct mode { training, translation, scoring, server };
|
||||
* @brief Command-line options parser
|
||||
*
|
||||
* New options and aliases should be defined within `addOptions*` methods.
|
||||
* ... unless they are specific to certain executables.
|
||||
* In that case, use a pattern like this (e.g., for a server):
|
||||
* int main(int argc, char* argv[]) {
|
||||
* ConfigParser cp(cli::mode::translation);
|
||||
* cp.addOption<int>("--port", // option name
|
||||
* "Server Options", // option group name
|
||||
* "Port for server.", // help string
|
||||
* 5678); // default value
|
||||
* auto opts = cp.parseOptions(argc,argv,true); // 'true' for validation
|
||||
* ...
|
||||
*
|
||||
*
|
||||
*/
|
||||
class ConfigParser {
|
||||
public:
|
||||
|
||||
ConfigParser(cli::mode mode);
|
||||
|
||||
ConfigParser(int argc, char** argv, cli::mode mode, bool validate = false)
|
||||
: modeServer_(mode == cli::mode::server),
|
||||
mode_(mode == cli::mode::server ? cli::mode::translation : mode) {
|
||||
: ConfigParser(mode) {
|
||||
parseOptions(argc, argv, validate);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
ConfigParser&
|
||||
addOption(const std::string& args,
|
||||
const std::string& group,
|
||||
const std::string& help,
|
||||
const T val) {
|
||||
std::string previous_group = cli_.switchGroup(group);
|
||||
cli_.add<T>(args,help,val);
|
||||
cli_.switchGroup(previous_group);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
ConfigParser&
|
||||
addOption(const std::string& args,
|
||||
const std::string& group,
|
||||
const std::string& help,
|
||||
const T val,
|
||||
const T implicit_val) {
|
||||
std::string previous_group = cli_.switchGroup(group);
|
||||
cli_.add<T>(args,help,val)->implicit_val(implicit_val);
|
||||
cli_.switchGroup(previous_group);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
ConfigParser&
|
||||
addOption(const std::string& args,
|
||||
const std::string& group,
|
||||
const std::string& help) {
|
||||
std::string previous_group = cli_.switchGroup(group);
|
||||
cli_.add<T>(args,help);
|
||||
cli_.switchGroup(previous_group);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Parse command-line options
|
||||
*
|
||||
@ -47,15 +96,18 @@ public:
|
||||
* @param argc
|
||||
* @param argv
|
||||
* @param validate Do or do not validate parsed options
|
||||
* @return (YAML::Node const&)config_
|
||||
*/
|
||||
void parseOptions(int argc, char** argv, bool validate);
|
||||
|
||||
YAML::Node getConfig() const;
|
||||
|
||||
Ptr<Options> parseOptions(int argc, char** argv, bool validate);
|
||||
YAML::Node const& getConfig() const;
|
||||
cli::mode getMode() const;
|
||||
std::string const& cmdLine() const;
|
||||
private:
|
||||
bool modeServer_;
|
||||
cli::CLIWrapper cli_;
|
||||
cli::mode mode_;
|
||||
YAML::Node config_;
|
||||
std::string cmdLine_;
|
||||
|
||||
// Check if the config contains value for option key
|
||||
bool has(const std::string& key) const {
|
||||
|
61
src/common/options.cpp
Normal file
61
src/common/options.cpp
Normal file
@ -0,0 +1,61 @@
|
||||
#include "options.h"
|
||||
|
||||
namespace marian {
|
||||
Options::Options() {}
|
||||
|
||||
Options::Options(const Options& other)
|
||||
: options_(YAML::Clone(other.options_)) {}
|
||||
|
||||
Options Options::clone() const {
|
||||
return Options(*this);
|
||||
}
|
||||
|
||||
YAML::Node& Options::getYaml() {
|
||||
return options_;
|
||||
}
|
||||
|
||||
const YAML::Node& Options::getYaml() const {
|
||||
return options_;
|
||||
}
|
||||
|
||||
void Options::parse(const std::string& yaml) {
|
||||
auto node = YAML::Load(yaml);
|
||||
for(auto it : node)
|
||||
options_[it.first.as<std::string>()] = YAML::Clone(it.second);
|
||||
}
|
||||
|
||||
void Options::merge(const YAML::Node& node, bool overwrite) {
|
||||
for(auto it : node)
|
||||
if(overwrite || !options_[it.first.as<std::string>()])
|
||||
options_[it.first.as<std::string>()] = YAML::Clone(it.second);
|
||||
}
|
||||
|
||||
void Options::merge(Ptr<Options> options) {
|
||||
merge(options->getYaml());
|
||||
}
|
||||
|
||||
std::string Options::str() {
|
||||
std::stringstream ss;
|
||||
ss << options_;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool Options::hasAndNotEmpty(const std::string& key) const {
|
||||
if(!has(key)) {
|
||||
return false;
|
||||
}
|
||||
if(options_[key].IsSequence()) {
|
||||
return options_[key].size() != 0;
|
||||
}
|
||||
try {
|
||||
return !options_[key].as<std::string>().empty();
|
||||
} catch(const YAML::BadConversion& e) {
|
||||
ABORT("Option '{}' is neither a sequence nor a text");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Options::has(const std::string& key) const {
|
||||
return options_[key];
|
||||
}
|
||||
}
|
@ -32,9 +32,9 @@ protected:
|
||||
YAML::Node options_;
|
||||
|
||||
public:
|
||||
Options() {}
|
||||
Options(const Options& other) : options_(YAML::Clone(other.options_)) {}
|
||||
|
||||
Options();
|
||||
Options(const Options& other);
|
||||
|
||||
// constructor with one or more key-value pairs
|
||||
// New<Options>("var1", val1, "var2", val2, ...)
|
||||
template <typename T, typename... Args>
|
||||
@ -54,16 +54,13 @@ public:
|
||||
/**
|
||||
* @brief Return a copy of the object that can be safely modified.
|
||||
*/
|
||||
Options clone() const { return Options(*this); }
|
||||
Options clone() const;
|
||||
|
||||
YAML::Node& getYaml() { return options_; }
|
||||
const YAML::Node& getYaml() const { return options_; }
|
||||
YAML::Node& getYaml();
|
||||
|
||||
void parse(const std::string& yaml) {
|
||||
auto node = YAML::Load(yaml);
|
||||
for(auto it : node)
|
||||
options_[it.first.as<std::string>()] = YAML::Clone(it.second);
|
||||
}
|
||||
const YAML::Node& getYaml() const;
|
||||
|
||||
void parse(const std::string& yaml);
|
||||
|
||||
/**
|
||||
* @brief Splice options from a YAML node
|
||||
@ -74,20 +71,10 @@ public:
|
||||
* @param node a YAML node to transfer the options from
|
||||
* @param overwrite overwrite all options
|
||||
*/
|
||||
void merge(YAML::Node& node, bool overwrite = false) {
|
||||
for(auto it : node)
|
||||
if(overwrite || !options_[it.first.as<std::string>()])
|
||||
options_[it.first.as<std::string>()] = YAML::Clone(it.second);
|
||||
}
|
||||
void merge(const YAML::Node& node, bool overwrite = false);
|
||||
void merge(Ptr<Options> options);
|
||||
|
||||
void merge(const YAML::Node& node, bool overwrite = false) { merge(node, overwrite); }
|
||||
void merge(Ptr<Options> options) { merge(options->getYaml()); }
|
||||
|
||||
std::string str() {
|
||||
std::stringstream ss;
|
||||
ss << options_;
|
||||
return ss.str();
|
||||
}
|
||||
std::string str();
|
||||
|
||||
template <typename T>
|
||||
void set(const std::string& key, T value) {
|
||||
@ -103,13 +90,13 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T get(const std::string& key) {
|
||||
T get(const std::string& key) const {
|
||||
ABORT_IF(!has(key), "Required option '{}' has not been set", key);
|
||||
return options_[key].as<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T get(const std::string& key, T defaultValue) {
|
||||
T get(const std::string& key, T defaultValue) const {
|
||||
if(has(key))
|
||||
return options_[key].as<T>();
|
||||
else
|
||||
@ -126,22 +113,9 @@ public:
|
||||
*
|
||||
* @return true if the option is defined and is a nonempty sequence or string
|
||||
*/
|
||||
bool hasAndNotEmpty(const std::string& key) const {
|
||||
if(!has(key)) {
|
||||
return false;
|
||||
}
|
||||
if(options_[key].IsSequence()) {
|
||||
return options_[key].size() != 0;
|
||||
}
|
||||
try {
|
||||
return !options_[key].as<std::string>().empty();
|
||||
} catch(const YAML::BadConversion&) {
|
||||
ABORT("Option '{}' is neither a sequence nor a text");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool hasAndNotEmpty(const std::string& key) const;
|
||||
|
||||
bool has(const std::string& key) const { return options_[key]; }
|
||||
bool has(const std::string& key) const;
|
||||
};
|
||||
|
||||
} // namespace marian
|
||||
|
@ -175,7 +175,17 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// reset gradients outside current shard
|
||||
auto reset = [this, shardSize](size_t idx, size_t begin, size_t end) {
|
||||
auto grad = graphs_[idx]->params()->grads();
|
||||
if (begin > 0)
|
||||
grad->subtensor(0, begin)->set(0);
|
||||
if (end < grad->size())
|
||||
grad->subtensor(end, grad->size()-end)->set(0);
|
||||
};
|
||||
|
||||
foreach(scatter);
|
||||
foreach(reset);
|
||||
}
|
||||
|
||||
void allGatherParams() const override {
|
||||
|
Loading…
Reference in New Issue
Block a user