merge with public master

This commit is contained in:
Marcin Junczys-Dowmunt 2019-10-25 22:24:59 -07:00
commit 1174cecbd6
12 changed files with 269 additions and 135 deletions

View File

@ -224,7 +224,14 @@ if(CUDA_FOUND)
endif()
else(CUDA_FOUND)
message(FATAL_ERROR "CUDA has not been found, set -DCOMPILE_CUDA=off to avoid this check and to compile the CPU version only")
message("
Cannot find suitable CUDA libraries. Specify the path explicitly with
-DCUDA_TOOLKIT_ROOT_DIR=/path/to/appropriate/cuda/installation
(hint: try /usr/local/$(readlink /usr/local/cuda))
OR compile the CPU-only version of Marian with
-DCOMPILE_CUDA=off
")
message(FATAL_ERROR "FATAL ERROR: No suitable CUDA library found.")
endif(CUDA_FOUND)
else(COMPILE_CUDA)

@ -1 +1 @@
Subproject commit 08c1485c2c925e4e1ed8a3e5df686ed5c7dc496e
Subproject commit 95b66e74d47107a2e3abad5a4c5338904129a25a

View File

@ -17,6 +17,7 @@ add_library(marian STATIC
common/config_parser.cpp
common/aliases.cpp
common/config_validator.cpp
common/options.cpp
common/binary.cpp
common/io.cpp
common/filesystem.cpp

View File

@ -110,8 +110,12 @@ CLIWrapper::CLIWrapper(Ptr<marian::Options> options,
CLIWrapper::~CLIWrapper() {}
void CLIWrapper::switchGroup(const std::string &name) {
currentGroup_ = name.empty() ? defaultGroup_ : name;
// Make `name` the active option group and report which group was active before.
// An empty `name` selects the default group instead.
std::string CLIWrapper::switchGroup(std::string name) {
  std::string formerGroup = currentGroup_;
  if(name.empty())
    currentGroup_ = defaultGroup_;
  else
    currentGroup_ = name;
  return formerGroup;
}
void CLIWrapper::parse(int argc, char **argv) {

View File

@ -212,8 +212,9 @@ public:
* Switch to different option group or to the default group if argument is empty.
*
* @param name Header of the option group
* @return Previous group.
*/
void switchGroup(const std::string &name = "");
std::string switchGroup(std::string name = "");
// Parse command-line arguments. Handles --help and --version options
void parse(int argc, char **argv);

View File

@ -1,9 +1,11 @@
#include "common/config.h"
#include "common/config_parser.h"
#include "common/file_stream.h"
#include "common/logging.h"
#include "common/options.h"
#include "common/regex.h"
#include "common/utils.h"
#include "common/version.h"
#include "common/regex.h"
#include <algorithm>
#include <set>
@ -14,35 +16,26 @@ namespace marian {
// @TODO: keep seed in a single place, now it is kept here and in Config/Options
size_t Config::seed = (size_t)time(0);
Config::Config(int argc,
char** argv,
cli::mode mode,
bool validate /*= true*/) {
initialize(argc, argv, mode, validate);
Config::Config(ConfigParser const& cp) {
initialize(cp);
}
Config::Config(int argc, char** argv, cli::mode mode, bool validate /*= true*/)
: Config(ConfigParser(argc, argv, mode, validate)) {}
Config::Config(const Config& other) : config_(YAML::Clone(other.config_)) {}
Config::Config(const Options& options) : config_(YAML::Clone(options.getYaml())) {}
void Config::initialize(int argc, char** argv, cli::mode mode, bool validate) {
auto parser = ConfigParser(argc, argv, mode, validate);
config_ = parser.getConfig();
void Config::initialize(ConfigParser const& cp) {
config_ = YAML::Clone(cp.getConfig());
cli::mode mode = cp.getMode();
createLoggers(this);
// echo version and command line
LOG(info, "[marian] Marian {}", buildVersion());
std::string cmdLine;
for(int i = 0; i < argc; i++) {
std::string arg = argv[i];
std::string quote; // attempt to quote special chars
if(arg.empty() || arg.find_first_of(" #`\"'\\${}|&^?*!()%><") != std::string::npos)
quote = "'";
arg = regex::regex_replace(arg, regex::regex("'"), "'\\''");
if(!cmdLine.empty())
cmdLine.push_back(' ');
cmdLine += quote + arg + quote;
}
std::string cmdLine = cp.cmdLine();
std::string hostname; int pid; std::tie
(hostname, pid) = utils::hostnameAndProcessId();
LOG(info, "[marian] Running on {} as process {} with command line:", hostname, pid);
@ -262,14 +255,17 @@ std::vector<DeviceId> Config::getDevices(Ptr<Options> options,
return devices;
}
Ptr<Options> parseOptions(int argc,
char** argv,
cli::mode mode,
bool validate /*= true*/) {
auto config = New<Config>(argc, argv, mode, validate);
auto options = New<Options>();
options->merge(config->get());
return options;
// Convenience wrapper: builds a ConfigParser for `mode`, lets it parse the
// command line (passing `validate` through) and returns the resulting options.
Ptr<Options> parseOptions(int argc, char** argv, cli::mode mode, bool validate) {
  ConfigParser parser(mode);
  Ptr<Options> parsed = parser.parseOptions(argc, argv, validate);
  return parsed;
}
std::ostream& operator<<(std::ostream& out, const Config& config) {
YAML::Emitter outYaml;
cli::OutputYaml(config.get(), outYaml);
out << outYaml.c_str();
return out;
}
} // namespace marian

View File

@ -38,6 +38,7 @@ public:
typedef YAML::Node YamlNode;
Config(ConfigParser const& cp);
// TODO: remove mode from this class
Config(int argc,
char** argv,
@ -47,7 +48,7 @@ public:
Config(const Config& other);
Config(const Options& options);
void initialize(int argc, char** argv, cli::mode mode, bool validate);
void initialize(ConfigParser const& cp);
bool has(const std::string& key) const;
@ -83,12 +84,7 @@ public:
void save(const std::string& name);
friend std::ostream& operator<<(std::ostream& out, const Config& config) {
YAML::Emitter outYaml;
cli::OutputYaml(config.get(), outYaml);
out << outYaml.c_str();
return out;
}
friend std::ostream& operator<<(std::ostream& out, const Config& config);
static std::vector<DeviceId> getDevices(Ptr<Options> options,
size_t myMPIRank = 0,

View File

@ -1,12 +1,13 @@
#include "common/config_parser.h"
#include "common/definitions.h"
#include "common/cli_helper.h"
#include "common/config.h"
#include "common/config_parser.h"
#include "common/config_validator.h"
#include "common/definitions.h"
#include "common/file_stream.h"
#include "common/logging.h"
#include "common/options.h"
#include "common/regex.h"
#include "common/utils.h"
#include <algorithm>
#include <set>
#include <stdexcept>
@ -49,6 +50,56 @@ const std::set<std::string> PATHS = {
};
// clang-format on
// Joins argv[0..argc) into one space-separated string for logging.
// Arguments that are empty or contain shell-special characters are wrapped in
// single quotes; every embedded single quote is rewritten as '\''.
// This is a best-effort attempt at quoting, not a guaranteed shell round-trip.
std::string escapeCmdLine(int argc, char** argv) {
  const char* specialChars = " #`\"'\\${}|&^?*!()%><";
  std::string result;
  for(int pos = 0; pos < argc; ++pos) {
    const std::string raw = argv[pos];
    // decide on quoting based on the unmodified argument
    const bool needsQuotes
        = raw.empty() || raw.find_first_of(specialChars) != std::string::npos;
    const std::string escaped = regex::regex_replace(raw, regex::regex("'"), "'\\''");

    if(!result.empty())
      result += ' ';
    if(needsQuotes)
      result += '\'';
    result += escaped;
    if(needsQuotes)
      result += '\'';
  }
  return result;
}
// Returns the stored command-line string (cmdLine_), which parseOptions()
// fills with the result of escapeCmdLine(argc, argv).
std::string const& ConfigParser::cmdLine() const {
return cmdLine_;
}
// Sets up the command-line interface for `mode` by registering all option
// groups and aliases; the command line itself is parsed later in parseOptions().
// Server mode is handled as translation mode plus the server-specific options.
// NOTE(review): the class appears to declare cli_ before config_, so cli_ is
// constructed from a config_ that has not been constructed yet — confirm that
// CLIWrapper's constructor only stores a reference and does not use the node.
ConfigParser::ConfigParser(cli::mode mode)
    : cli_(config_, "Marian: Fast Neural Machine Translation in C++", "General options", "", 40),
      mode_(mode == cli::mode::server ? cli::mode::translation : mode) {
  const bool serverMode = (mode == cli::mode::server);

  addOptionsGeneral(cli_);
  if(serverMode)
    addOptionsServer(cli_);
  addOptionsModel(cli_);

  // mode_ can no longer be `server` here (mapped to translation above)
  if(mode_ == cli::mode::training) {
    addOptionsTraining(cli_);
    addOptionsValidation(cli_);
  } else if(mode_ == cli::mode::translation) {
    addOptionsTranslation(cli_);
  } else if(mode_ == cli::mode::scoring) {
    addOptionsScoring(cli_);
  } else {
    ABORT("wrong CLI mode");
  }

  addAliases(cli_);
}
void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
int defaultWorkspace = (mode_ == cli::mode::translation) ? 512 : 2048;
@ -87,14 +138,16 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
// Registers the server-only options under the "Server options" group and then
// restores whichever group was active on entry.
void ConfigParser::addOptionsServer(cli::CLIWrapper& cli) {
  // clang-format off
  const std::string groupOnEntry = cli.switchGroup("Server options");
  cli.add<size_t>("--port,-p", "Port number for web socket server", 8080);
  cli.switchGroup(groupOnEntry);
  // clang-format on
}
void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
cli.switchGroup("Model options");
auto previous_group = cli.switchGroup("Model options");
// clang-format off
if(mode_ == cli::mode::translation) {
@ -259,11 +312,12 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
cli.add<float>("--transformer-dropout-ffn",
"Dropout for transformer filter (0 = no dropout)");
}
cli.switchGroup(previous_group);
// clang-format on
}
void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
cli.switchGroup("Training options");
auto previous_group = cli.switchGroup("Training options");
// clang-format off
cli.add<std::string>("--cost-type", // @TODO: rename to loss-type
"Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-mean");
@ -447,11 +501,12 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
cli.add<std::vector<std::string>>("--task",
"Use predefined set of options. Possible values: transformer, transformer-big");
cli.switchGroup(previous_group);
// clang-format on
}
void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
cli.switchGroup("Validation set options");
auto previous_group = cli.switchGroup("Validation set options");
// clang-format off
cli.add<std::vector<std::string>>("--valid-sets",
@ -508,11 +563,12 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
"Keep best model for each validation metric");
cli.add<std::string>("--valid-log",
"Log validation scores to file given by arg");
cli.switchGroup(previous_group);
// clang-format on
}
void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
cli.switchGroup("Translator options");
auto previous_group = cli.switchGroup("Translator options");
// clang-format off
cli.add<std::vector<std::string>>("--input,-i",
@ -572,11 +628,12 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
// add ULR settings
addSuboptionsULR(cli);
cli.switchGroup(previous_group);
// clang-format on
}
void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
cli.switchGroup("Scorer options");
auto previous_group = cli.switchGroup("Scorer options");
// clang-format off
cli.add<bool>("--no-reload",
@ -610,15 +667,13 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
cli.add<bool>("--optimize",
"Optimize speed aggressively sacrificing memory or precision");
//@TODO: gemm-type missing? fix this
cli.add<bool>("--fp16",
"Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
cli.add<std::vector<std::string>>("--precision",
"Mixed precision for inference, set parameter type in expression graph",
{"float32"});
cli.switchGroup(previous_group);
// clang-format on
}
@ -747,46 +802,20 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
// clang-format on
}
void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
cli::CLIWrapper cli(config_,
"Marian: Fast Neural Machine Translation in C++",
"General options",
"",
40);
addOptionsGeneral(cli);
if(modeServer_)
addOptionsServer(cli);
addOptionsModel(cli);
cli::mode ConfigParser::getMode() const { return mode_; }
// clang-format off
switch(mode_) {
case cli::mode::training:
addOptionsTraining(cli);
addOptionsValidation(cli);
break;
case cli::mode::translation:
addOptionsTranslation(cli);
break;
case cli::mode::scoring:
addOptionsScoring(cli);
break;
default:
ABORT("wrong CLI mode");
break;
}
// clang-format on
addAliases(cli);
Ptr<Options> ConfigParser::parseOptions(int argc, char** argv, bool doValidate){
cmdLine_ = escapeCmdLine(argc,argv);
// parse command-line options and fill wrapped YAML config
cli.parse(argc, argv);
cli_.parse(argc, argv);
// get paths to extra config files
auto configPaths = findConfigPaths();
if(!configPaths.empty()) {
auto config = loadConfigFiles(configPaths);
cli.updateConfig(config,
cli_.updateConfig(config,
cli::OptionPriority::ConfigFile,
"There are option(s) in a config file that are not expected");
}
@ -807,15 +836,18 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
config_.remove("dump-config");
if(dumpMode == "expand") {
cli.parseAliases();
cli_.parseAliases();
}
bool minimal = (dumpMode == "minimal" || dumpMode == "expand");
std::cout << cli.dumpConfig(minimal) << std::endl;
std::cout << cli_.dumpConfig(minimal) << std::endl;
exit(0);
}
cli.parseAliases();
cli_.parseAliases();
auto opts = New<Options>();
opts->merge(Config(*this).get());
return opts;
}
std::vector<std::string> ConfigParser::findConfigPaths() {
@ -877,7 +909,7 @@ YAML::Node ConfigParser::loadConfigFiles(const std::vector<std::string>& paths)
return configAll;
}
YAML::Node ConfigParser::getConfig() const {
// Returns a reference to the parser's YAML configuration node (config_), i.e.
// no copy is made here; Config::initialize() clones the node it gets from here.
const YAML::Node& ConfigParser::getConfig() const {
return config_;
}
} // namespace marian

View File

@ -14,7 +14,6 @@
namespace marian {
namespace cli {
// CLI mode
enum struct mode { training, translation, scoring, server };
} // namespace cli
@ -22,15 +21,65 @@ enum struct mode { training, translation, scoring, server };
* @brief Command-line options parser
*
* New options and aliases should be defined within `addOptions*` methods.
* The exception is options that are specific to certain executables: add those
* from the outside via addOption(), using a pattern like this (e.g., for a server):
* int main(int argc, char* argv[]) {
* ConfigParser cp(cli::mode::translation);
* cp.addOption<int>("--port", // option name
* "Server Options", // option group name
* "Port for server.", // help string
* 5678); // default value
* auto opts = cp.parseOptions(argc,argv,true); // 'true' for validation
* ...
*
*
*/
class ConfigParser {
public:
ConfigParser(cli::mode mode);
ConfigParser(int argc, char** argv, cli::mode mode, bool validate = false)
: modeServer_(mode == cli::mode::server),
mode_(mode == cli::mode::server ? cli::mode::translation : mode) {
: ConfigParser(mode) {
parseOptions(argc, argv, validate);
}
/**
 * @brief Register an additional option that has a default value.
 *
 * Switches to option group `group`, adds the option, and then restores the
 * group that was active before the call.
 *
 * @param args Option name(s), e.g. "--port"
 * @param group Header of the option group the option is added to
 * @param help Help string
 * @param val Default value
 * @return *this, so that calls can be chained
 */
template <typename T>
ConfigParser& addOption(const std::string& args,
                        const std::string& group,
                        const std::string& help,
                        const T val) {
  const std::string groupOnEntry = cli_.switchGroup(group);
  cli_.add<T>(args, help, val);
  cli_.switchGroup(groupOnEntry);
  return *this;
}
/**
 * @brief Register an additional option with a default and an implicit value.
 *
 * Switches to option group `group`, adds the option, sets its implicit value,
 * and then restores the group that was active before the call.
 *
 * @param args Option name(s), e.g. "--port"
 * @param group Header of the option group the option is added to
 * @param help Help string
 * @param val Default value
 * @param implicit_val Value forwarded to the option's implicit_val() setter
 *        (presumably used when the option is given without an argument — confirm)
 * @return *this, so that calls can be chained
 */
template <typename T>
ConfigParser& addOption(const std::string& args,
                        const std::string& group,
                        const std::string& help,
                        const T val,
                        const T implicit_val) {
  const std::string groupOnEntry = cli_.switchGroup(group);
  auto option = cli_.add<T>(args, help, val);
  option->implicit_val(implicit_val);
  cli_.switchGroup(groupOnEntry);
  return *this;
}
/**
 * @brief Register an additional option without a default value.
 *
 * Switches to option group `group`, adds the option, and then restores the
 * group that was active before the call.
 *
 * @param args Option name(s), e.g. "--port"
 * @param group Header of the option group the option is added to
 * @param help Help string
 * @return *this, so that calls can be chained
 */
template <typename T>
ConfigParser& addOption(const std::string& args,
                        const std::string& group,
                        const std::string& help) {
  const std::string groupOnEntry = cli_.switchGroup(group);
  cli_.add<T>(args, help);
  cli_.switchGroup(groupOnEntry);
  return *this;
}
/**
* @brief Parse command-line options
*
@ -47,15 +96,18 @@ public:
* @param argc
* @param argv
* @param validate Do or do not validate parsed options
* @return Options object built from the parsed configuration
*/
void parseOptions(int argc, char** argv, bool validate);
YAML::Node getConfig() const;
Ptr<Options> parseOptions(int argc, char** argv, bool validate);
YAML::Node const& getConfig() const;
cli::mode getMode() const;
std::string const& cmdLine() const;
private:
bool modeServer_;
cli::CLIWrapper cli_;
cli::mode mode_;
YAML::Node config_;
std::string cmdLine_;
// Check if the config contains value for option key
bool has(const std::string& key) const {

61
src/common/options.cpp Normal file
View File

@ -0,0 +1,61 @@
#include "options.h"
namespace marian {
// Creates an option set whose YAML node (options_) is default-constructed.
Options::Options() = default;
// Copy constructor: initializes options_ from YAML::Clone() of the other
// object's node rather than by plain node copy, so that the new object can be
// modified independently (see yaml-cpp documentation on Node copy semantics).
Options::Options(const Options& other)
: options_(YAML::Clone(other.options_)) {}
// Returns a copy of this object made with the copy constructor, which clones
// the underlying YAML node.
Options Options::clone() const {
return Options(*this);
}
// Mutable access to the underlying YAML node; changes made through the
// returned reference modify this object's options directly.
YAML::Node& Options::getYaml() {
return options_;
}
// Read-only access to the underlying YAML node (no copy is made).
const YAML::Node& Options::getYaml() const {
return options_;
}
// Parses `yaml` with YAML::Load and stores a clone of every top-level entry in
// options_ under its key; entries with an already present key are overwritten.
void Options::parse(const std::string& yaml) {
  auto parsed = YAML::Load(yaml);
  for(auto entry : parsed) {
    const auto key = entry.first.as<std::string>();
    options_[key] = YAML::Clone(entry.second);
  }
}
// Copies the top-level entries of `node` into options_ (each value is cloned).
// Keys that already exist are kept unless `overwrite` is true.
void Options::merge(const YAML::Node& node, bool overwrite) {
  for(auto entry : node) {
    const auto key = entry.first.as<std::string>();
    // skip existing keys unless the caller asked to overwrite them
    if(!overwrite && options_[key])
      continue;
    options_[key] = YAML::Clone(entry.second);
  }
}
// Merges the entries of another Options object by forwarding its YAML node to
// merge(const YAML::Node&, bool) with that overload's default `overwrite` value.
void Options::merge(Ptr<Options> options) {
merge(options->getYaml());
}
// Returns the text obtained by streaming the underlying YAML node.
std::string Options::str() {
  std::ostringstream buffer;
  buffer << options_;
  return buffer.str();
}
// Checks that option `key` is set and holds a non-empty value.
//
// Returns false if has(key) is false. For a sequence node the result is
// whether it has at least one element; otherwise the node is read as a string
// and the result is whether that string is non-empty.
// Aborts (naming the offending option) if the node is not a sequence and
// cannot be converted to a string.
bool Options::hasAndNotEmpty(const std::string& key) const {
  if(!has(key)) {
    return false;
  }
  if(options_[key].IsSequence()) {
    return options_[key].size() != 0;
  }
  try {
    return !options_[key].as<std::string>().empty();
  } catch(const YAML::BadConversion&) {
    // pass the key so that the '{}' placeholder in the message is filled in
    ABORT("Option '{}' is neither a sequence nor a text", key);
  }
  return false;
}
// Returns the boolean conversion of the node stored under `key`
// (in yaml-cpp a node converts to true when it is defined — see yaml-cpp docs).
bool Options::has(const std::string& key) const {
return options_[key];
}
}

View File

@ -32,9 +32,9 @@ protected:
YAML::Node options_;
public:
Options() {}
Options(const Options& other) : options_(YAML::Clone(other.options_)) {}
Options();
Options(const Options& other);
// constructor with one or more key-value pairs
// New<Options>("var1", val1, "var2", val2, ...)
template <typename T, typename... Args>
@ -54,16 +54,13 @@ public:
/**
* @brief Return a copy of the object that can be safely modified.
*/
Options clone() const { return Options(*this); }
Options clone() const;
YAML::Node& getYaml() { return options_; }
const YAML::Node& getYaml() const { return options_; }
YAML::Node& getYaml();
void parse(const std::string& yaml) {
auto node = YAML::Load(yaml);
for(auto it : node)
options_[it.first.as<std::string>()] = YAML::Clone(it.second);
}
const YAML::Node& getYaml() const;
void parse(const std::string& yaml);
/**
* @brief Splice options from a YAML node
@ -74,20 +71,10 @@ public:
* @param node a YAML node to transfer the options from
* @param overwrite overwrite all options
*/
void merge(YAML::Node& node, bool overwrite = false) {
for(auto it : node)
if(overwrite || !options_[it.first.as<std::string>()])
options_[it.first.as<std::string>()] = YAML::Clone(it.second);
}
void merge(const YAML::Node& node, bool overwrite = false);
void merge(Ptr<Options> options);
void merge(const YAML::Node& node, bool overwrite = false) { merge(node, overwrite); }
void merge(Ptr<Options> options) { merge(options->getYaml()); }
std::string str() {
std::stringstream ss;
ss << options_;
return ss.str();
}
std::string str();
template <typename T>
void set(const std::string& key, T value) {
@ -103,13 +90,13 @@ public:
}
template <typename T>
T get(const std::string& key) {
T get(const std::string& key) const {
ABORT_IF(!has(key), "Required option '{}' has not been set", key);
return options_[key].as<T>();
}
template <typename T>
T get(const std::string& key, T defaultValue) {
T get(const std::string& key, T defaultValue) const {
if(has(key))
return options_[key].as<T>();
else
@ -126,22 +113,9 @@ public:
*
* @return true if the option is defined and is a nonempty sequence or string
*/
bool hasAndNotEmpty(const std::string& key) const {
if(!has(key)) {
return false;
}
if(options_[key].IsSequence()) {
return options_[key].size() != 0;
}
try {
return !options_[key].as<std::string>().empty();
} catch(const YAML::BadConversion&) {
ABORT("Option '{}' is neither a sequence nor a text");
}
return false;
}
bool hasAndNotEmpty(const std::string& key) const;
bool has(const std::string& key) const { return options_[key]; }
bool has(const std::string& key) const;
};
} // namespace marian

View File

@ -175,7 +175,17 @@ public:
}
};
// reset gradients outside current shard: zero everything before `begin` and
// everything from `end` onwards (assumes subtensor(offset, size) — TODO confirm)
auto reset = [this](size_t idx, size_t begin, size_t end) {
  auto grad = graphs_[idx]->params()->grads();
  if(begin > 0)
    grad->subtensor(0, begin)->set(0);
  if(end < grad->size())
    grad->subtensor(end, grad->size() - end)->set(0);
};
foreach(scatter);
foreach(reset);
}
void allGatherParams() const override {