mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-27 10:33:14 +03:00
Towards YAML configurations
This commit is contained in:
parent
273487f496
commit
7e76b61f88
@ -18,6 +18,13 @@ else(Boost_FOUND)
|
||||
message(SEND_ERROR "Cannot find Boost libraries. Terminating." )
|
||||
endif(Boost_FOUND)
|
||||
|
||||
find_package (YamlCpp)
|
||||
if (YAMLCPP_FOUND)
|
||||
include_directories(${YAMLCPP_INCLUDE_DIRS})
|
||||
set(EXT_LIBS ${EXT_LIBS} ${YAMLCPP_LIBRARY})
|
||||
endif (YAMLCPP_FOUND)
|
||||
|
||||
|
||||
set(KENLM CACHE STRING "Path to compiled kenlm directory")
|
||||
if (NOT EXISTS "${KENLM}/build/lib/libkenlm.a")
|
||||
message(FATAL_ERROR "Could not find ${KENLM}/build/lib/libkenlm.a")
|
||||
|
98
cmake/FindYamlCpp.cmake
Normal file
98
cmake/FindYamlCpp.cmake
Normal file
@ -0,0 +1,98 @@
|
||||
# Locate yaml-cpp
|
||||
#
|
||||
# This module defines
|
||||
# YAMLCPP_FOUND, if false, do not try to link to yaml-cpp
|
||||
# YAMLCPP_LIBNAME, name of yaml library
|
||||
# YAMLCPP_LIBRARY, where to find yaml-cpp
|
||||
# YAMLCPP_LIBRARY_RELEASE, where to find Release or RelWithDebInfo yaml-cpp
|
||||
# YAMLCPP_LIBRARY_DEBUG, where to find Debug yaml-cpp
|
||||
# YAMLCPP_INCLUDE_DIR, where to find yaml.h
|
||||
# YAMLCPP_LIBRARY_DIR, the directories to find YAMLCPP_LIBRARY
|
||||
#
|
||||
# By default, the dynamic libraries of yaml-cpp will be found. To find the static ones instead,
|
||||
# you must set the YAMLCPP_USE_STATIC_LIBS variable to TRUE before calling find_package(YamlCpp ...)
|
||||
|
||||
# attempt to find static library first if this is set
|
||||
if(YAMLCPP_USE_STATIC_LIBS)
|
||||
set(YAMLCPP_STATIC libyaml-cpp.a)
|
||||
set(YAMLCPP_STATIC_DEBUG libyaml-cpp-dbg.a)
|
||||
endif()
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Windows") ### Set Yaml libary name for Windows
|
||||
set(YAMLCPP_LIBNAME "libyaml-cppmd" CACHE STRING "Name of YAML library")
|
||||
set(YAMLCPP_LIBNAME optimized ${YAMLCPP_LIBNAME} debug ${YAMLCPP_LIBNAME}d)
|
||||
else() ### Set Yaml libary name for Unix, Linux, OS X, etc
|
||||
set(YAMLCPP_LIBNAME "yaml-cpp" CACHE STRING "Name of YAML library")
|
||||
endif()
|
||||
|
||||
# find the yaml-cpp include directory
|
||||
find_path(YAMLCPP_INCLUDE_DIR
|
||||
NAMES yaml-cpp/yaml.h
|
||||
PATH_SUFFIXES include
|
||||
PATHS
|
||||
${PROJECT_SOURCE_DIR}/dependencies/yaml-cpp-0.5.1/include
|
||||
~/Library/Frameworks/yaml-cpp/include/
|
||||
/Library/Frameworks/yaml-cpp/include/
|
||||
/usr/local/include/
|
||||
/usr/include/
|
||||
/sw/yaml-cpp/ # Fink
|
||||
/opt/local/yaml-cpp/ # DarwinPorts
|
||||
/opt/csw/yaml-cpp/ # Blastwave
|
||||
/opt/yaml-cpp/)
|
||||
|
||||
# find the release yaml-cpp library
|
||||
find_library(YAMLCPP_LIBRARY_RELEASE
|
||||
NAMES ${YAMLCPP_STATIC} yaml-cpp libyaml-cppmd.lib
|
||||
PATH_SUFFIXES lib64 lib Release RelWithDebInfo
|
||||
PATHS
|
||||
${PROJECT_SOURCE_DIR}/dependencies/yaml-cpp-0.5.1/
|
||||
${PROJECT_SOURCE_DIR}/dependencies/yaml-cpp-0.5.1/build
|
||||
~/Library/Frameworks
|
||||
/Library/Frameworks
|
||||
/usr/local
|
||||
/usr
|
||||
/sw
|
||||
/opt/local
|
||||
/opt/csw
|
||||
/opt)
|
||||
|
||||
# find the debug yaml-cpp library
|
||||
find_library(YAMLCPP_LIBRARY_DEBUG
|
||||
NAMES ${YAMLCPP_STATIC_DEBUG} yaml-cpp-dbg libyaml-cppmdd.lib
|
||||
PATH_SUFFIXES lib64 lib Debug
|
||||
PATHS
|
||||
${PROJECT_SOURCE_DIR}/dependencies/yaml-cpp-0.5.1/
|
||||
${PROJECT_SOURCE_DIR}/dependencies/yaml-cpp-0.5.1/build
|
||||
~/Library/Frameworks
|
||||
/Library/Frameworks
|
||||
/usr/local
|
||||
/usr
|
||||
/sw
|
||||
/opt/local
|
||||
/opt/csw
|
||||
/opt)
|
||||
|
||||
# set library vars
|
||||
set(YAMLCPP_LIBRARY ${YAMLCPP_LIBRARY_RELEASE})
|
||||
if(CMAKE_BUILD_TYPE MATCHES Debug AND EXISTS ${YAMLCPP_LIBRARY_DEBUG})
|
||||
set(YAMLCPP_LIBRARY ${YAMLCPP_LIBRARY_DEBUG})
|
||||
endif()
|
||||
|
||||
get_filename_component(YAMLCPP_LIBRARY_RELEASE_DIR ${YAMLCPP_LIBRARY_RELEASE} PATH)
|
||||
get_filename_component(YAMLCPP_LIBRARY_DEBUG_DIR ${YAMLCPP_LIBRARY_DEBUG} PATH)
|
||||
set(YAMLCPP_LIBRARY_DIR ${YAMLCPP_LIBRARY_RELEASE_DIR} ${YAMLCPP_LIBRARY_DEBUG_DIR})
|
||||
|
||||
# handle the QUIETLY and REQUIRED arguments and set YAMLCPP_FOUND to TRUE if all listed variables are TRUE
|
||||
include(FindPackageHandleStandardArgs)
|
||||
FIND_PACKAGE_HANDLE_STANDARD_ARGS(YamlCpp DEFAULT_MSG
|
||||
YAMLCPP_INCLUDE_DIR
|
||||
YAMLCPP_LIBRARY
|
||||
YAMLCPP_LIBRARY_DIR)
|
||||
mark_as_advanced(
|
||||
YAMLCPP_INCLUDE_DIR
|
||||
YAMLCPP_LIBRARY_DIR
|
||||
YAMLCPP_LIBRARY
|
||||
YAMLCPP_LIBRARY_RELEASE
|
||||
YAMLCPP_LIBRARY_RELEASE_DIR
|
||||
YAMLCPP_LIBRARY_DEBUG
|
||||
YAMLCPP_LIBRARY_DEBUG_DIR)
|
@ -15,6 +15,7 @@ add_library(librescorer OBJECT
|
||||
)
|
||||
|
||||
add_library(libamunn OBJECT
|
||||
decoder/config.cpp
|
||||
decoder/kenlm.cpp
|
||||
)
|
||||
|
||||
@ -39,7 +40,7 @@ cuda_add_executable(
|
||||
)
|
||||
|
||||
foreach(exec amunn rescorer)
|
||||
target_link_libraries(${exec} ${EXT_LIBS})
|
||||
target_link_libraries(${exec} ${EXT_LIBS} cuda)
|
||||
cuda_add_cublas_to_target(${exec})
|
||||
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
|
||||
endforeach(exec)
|
||||
|
173
src/decoder/config.cpp
Normal file
173
src/decoder/config.cpp
Normal file
@ -0,0 +1,173 @@
|
||||
#include <set>
|
||||
|
||||
#include "config.h"
|
||||
|
||||
#define SET_OPTION(key, type) \
|
||||
if(!vm_[key].defaulted() || !config_[key]) { \
|
||||
config_[key] = vm_[key].as<type>(); \
|
||||
}
|
||||
|
||||
#define SET_OPTION_NONDEFAULT(key, type) \
|
||||
if(vm_.count(key) > 0) { \
|
||||
config_[key] = vm_[key].as<type>(); \
|
||||
}
|
||||
|
||||
bool Config::Has(const std::string& key) {
|
||||
return config_[key];
|
||||
}
|
||||
|
||||
YAML::Node& Config::Get() {
|
||||
return config_;
|
||||
}
|
||||
|
||||
void Config::AddOptions(size_t argc, char** argv) {
|
||||
namespace po = boost::program_options;
|
||||
po::options_description general("General options");
|
||||
|
||||
std::string configPath;
|
||||
std::vector<size_t> devices;
|
||||
std::vector<size_t> tabMap;
|
||||
std::vector<float> weights;
|
||||
|
||||
std::vector<std::string> modelPaths;
|
||||
std::vector<std::string> lmPaths;
|
||||
std::vector<std::string> sourceVocabPaths;
|
||||
std::string targetVocabPath;
|
||||
|
||||
general.add_options()
|
||||
("config,c", po::value(&configPath),
|
||||
"Configuration file")
|
||||
("model,m", po::value(&modelPaths)->multitoken()->required(),
|
||||
"Path to neural translation model(s)")
|
||||
("source,s", po::value(&sourceVocabPaths)->multitoken()->required(),
|
||||
"Path to source vocabulary file.")
|
||||
("target,t", po::value(&targetVocabPath)->required(),
|
||||
"Path to target vocabulary file.")
|
||||
("ape", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Add APE-penalty")
|
||||
("lm,l", po::value(&lmPaths)->multitoken(),
|
||||
"Path to KenLM language model(s)")
|
||||
("tab-map", po::value(&tabMap)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
|
||||
"tab map")
|
||||
("devices,d", po::value(&devices)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
|
||||
"CUDA device(s) to use, set to 0 by default, "
|
||||
"e.g. set to 0 1 to use gpu0 and gpu1. "
|
||||
"Implicitly sets minimal number of threads to number of devices.")
|
||||
("threads-per-device", po::value<size_t>()->default_value(1),
|
||||
"Number of threads per device, total thread count equals threads x devices")
|
||||
("help,h", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Print this help message and exit")
|
||||
;
|
||||
|
||||
po::options_description search("Search options");
|
||||
search.add_options()
|
||||
("beam-size,b", po::value<size_t>()->default_value(12),
|
||||
"Decoding beam-size")
|
||||
("normalize,n", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Normalize scores by translation length after decoding")
|
||||
("n-best", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Output n-best list with n = beam-size")
|
||||
("weights,w", po::value(&weights)->multitoken()->default_value(std::vector<float>(1, 1.0), "1.0"),
|
||||
"Model weights (for neural models and KenLM models)")
|
||||
("show-weights", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Output used weights to stdout and exit")
|
||||
("load-weights", po::value<std::string>(),
|
||||
"Load scorer weights from this file")
|
||||
;
|
||||
|
||||
po::options_description kenlm("KenLM specific options");
|
||||
kenlm.add_options()
|
||||
("kenlm-batch-size", po::value<size_t>()->default_value(1000),
|
||||
"Batch size for batched queries to KenLM")
|
||||
("kenlm-batch-threads", po::value<size_t>()->default_value(4),
|
||||
"Concurrent worker threads for batch processing")
|
||||
;
|
||||
|
||||
po::options_description cmdline_options("Allowed options");
|
||||
cmdline_options.add(general);
|
||||
cmdline_options.add(search);
|
||||
cmdline_options.add(kenlm);
|
||||
|
||||
po::variables_map vm_;
|
||||
try {
|
||||
po::store(po::command_line_parser(argc,argv)
|
||||
.options(cmdline_options).run(), vm_);
|
||||
po::notify(vm_);
|
||||
}
|
||||
catch (std::exception& e) {
|
||||
std::cerr << "Error: " << e.what() << std::endl << std::endl;
|
||||
|
||||
std::cerr << "Usage: " + std::string(argv[0]) + " [options]" << std::endl;
|
||||
std::cerr << cmdline_options << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (vm_["help"].as<bool>()) {
|
||||
std::cerr << "Usage: " + std::string(argv[0]) + " [options]" << std::endl;
|
||||
std::cerr << cmdline_options << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(configPath.size())
|
||||
config_ = YAML::LoadFile(configPath);
|
||||
|
||||
SET_OPTION("model", std::vector<std::string>)
|
||||
SET_OPTION_NONDEFAULT("lm", std::vector<std::string>)
|
||||
SET_OPTION("ape", bool)
|
||||
SET_OPTION("source", std::vector<std::string>)
|
||||
SET_OPTION("target", std::string)
|
||||
|
||||
SET_OPTION("n-best", bool)
|
||||
SET_OPTION("normalize", bool)
|
||||
SET_OPTION("beam-size", size_t)
|
||||
SET_OPTION("threads-per-device", size_t)
|
||||
SET_OPTION("devices", std::vector<size_t>)
|
||||
SET_OPTION("tab-map", std::vector<size_t>)
|
||||
|
||||
SET_OPTION("weights", std::vector<float>)
|
||||
SET_OPTION("show-weights", bool)
|
||||
SET_OPTION_NONDEFAULT("load-weights", std::string)
|
||||
|
||||
SET_OPTION("kenlm-batch-size", size_t)
|
||||
SET_OPTION("kenlm-batch-threads", size_t)
|
||||
}
|
||||
|
||||
void OutputRec(const YAML::Node node, YAML::Emitter& out) {
|
||||
std::set<std::string> flow = { "weights", "devices", "tab-map" };
|
||||
std::set<std::string> sorter;
|
||||
switch (node.Type()) {
|
||||
case YAML::NodeType::Null:
|
||||
out << node; break;
|
||||
case YAML::NodeType::Scalar:
|
||||
out << node; break;
|
||||
case YAML::NodeType::Sequence:
|
||||
out << YAML::BeginSeq;
|
||||
for(auto&& n : node)
|
||||
OutputRec(n, out);
|
||||
out << YAML::EndSeq;
|
||||
break;
|
||||
case YAML::NodeType::Map:
|
||||
for(auto& n : node)
|
||||
sorter.insert(n.first.as<std::string>());
|
||||
out << YAML::BeginMap;
|
||||
for(auto& key : sorter) {
|
||||
out << YAML::Key;
|
||||
out << key;
|
||||
out << YAML::Value;
|
||||
if(flow.count(key))
|
||||
out << YAML::Flow;
|
||||
OutputRec(node[key], out);
|
||||
}
|
||||
out << YAML::EndMap;
|
||||
break;
|
||||
case YAML::NodeType::Undefined:
|
||||
out << node; break;
|
||||
}
|
||||
}
|
||||
|
||||
void Config::LogOptions() {
|
||||
std::stringstream ss;
|
||||
YAML::Emitter out;
|
||||
OutputRec(config_, out);
|
||||
LOG(info) << "Options: \n" << out.c_str();
|
||||
}
|
31
src/decoder/config.h
Normal file
31
src/decoder/config.h
Normal file
@ -0,0 +1,31 @@
|
||||
#pragma once
|
||||
|
||||
#include <yaml-cpp/yaml.h>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "logging.h"
|
||||
|
||||
class Config {
|
||||
private:
|
||||
YAML::Node config_;
|
||||
|
||||
public:
|
||||
bool Has(const std::string& key);
|
||||
|
||||
template <typename T>
|
||||
T Get(const std::string& key) {
|
||||
return config_[key].as<T>();
|
||||
}
|
||||
|
||||
YAML::Node& Get();
|
||||
|
||||
void AddOptions(size_t argc, char** argv);
|
||||
|
||||
template <class OStream>
|
||||
friend OStream& operator<<(OStream& out, const Config& config) {
|
||||
out << config.config_;
|
||||
return out;
|
||||
}
|
||||
|
||||
void LogOptions();
|
||||
};
|
@ -1,7 +1,10 @@
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#include <yaml-cpp/yaml.h>
|
||||
|
||||
#include "god.h"
|
||||
#include "config.h"
|
||||
#include "scorer.h"
|
||||
#include "threadpool.h"
|
||||
#include "encoder_decoder.h"
|
||||
@ -10,16 +13,6 @@
|
||||
|
||||
God God::instance_;
|
||||
|
||||
God& God::Init(const std::string& initString) {
|
||||
std::vector<std::string> args = po::split_unix(initString);
|
||||
int argc = args.size() + 1;
|
||||
char* argv[argc];
|
||||
argv[0] = const_cast<char*>("dummy");
|
||||
for(int i = 1; i < argc; i++)
|
||||
argv[i] = const_cast<char*>(args[i-1].c_str());
|
||||
return Init(argc, argv);
|
||||
}
|
||||
|
||||
God& God::Init(int argc, char** argv) {
|
||||
return Summon().NonStaticInit(argc, argv);
|
||||
}
|
||||
@ -31,103 +24,24 @@ God& God::NonStaticInit(int argc, char** argv) {
|
||||
progress_ = spdlog::stderr_logger_mt("progress");
|
||||
progress_->set_pattern("%v");
|
||||
|
||||
po::options_description general("General options");
|
||||
|
||||
std::vector<size_t> devices;
|
||||
std::vector<std::string> modelPaths;
|
||||
std::vector<std::string> lmPaths;
|
||||
std::vector<std::string> sourceVocabPaths;
|
||||
std::string targetVocabPath;
|
||||
|
||||
general.add_options()
|
||||
("model,m", po::value(&modelPaths)->multitoken()->required(),
|
||||
"Path to neural translation model(s)")
|
||||
("source,s", po::value(&sourceVocabPaths)->multitoken()->required(),
|
||||
"Path to source vocabulary file.")
|
||||
("target,t", po::value(&targetVocabPath)->required(),
|
||||
"Path to target vocabulary file.")
|
||||
("ape", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Add APE-penalty")
|
||||
("lm,l", po::value(&lmPaths)->multitoken(),
|
||||
"Path to KenLM language model(s)")
|
||||
("tab-map", po::value(&tabMap_)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
|
||||
"tab map")
|
||||
("devices,d", po::value(&devices)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
|
||||
"CUDA device(s) to use, set to 0 by default, "
|
||||
"e.g. set to 0 1 to use gpu0 and gpu1. "
|
||||
"Implicitly sets minimal number of threads to number of devices.")
|
||||
("threads-per-device", po::value<size_t>()->default_value(1),
|
||||
"Number of threads per device, total thread count equals threads x devices")
|
||||
("help,h", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Print this help message and exit")
|
||||
;
|
||||
|
||||
po::options_description search("Search options");
|
||||
search.add_options()
|
||||
("beam-size,b", po::value<size_t>()->default_value(12),
|
||||
"Decoding beam-size")
|
||||
("normalize,n", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Normalize scores by translation length after decoding")
|
||||
("n-best", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Output n-best list with n = beam-size")
|
||||
("weights,w", po::value(&weights_)->multitoken()->default_value(std::vector<float>(1, 1.0), "1.0"),
|
||||
"Model weights (for neural models and KenLM models)")
|
||||
("show-weights", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Output used weights to stdout and exit")
|
||||
("load-weights", po::value<std::string>(),
|
||||
"Load scorer weights from this file")
|
||||
;
|
||||
|
||||
po::options_description kenlm("KenLM specific options");
|
||||
kenlm.add_options()
|
||||
("kenlm-batch-size", po::value<size_t>()->default_value(1000),
|
||||
"Batch size for batched queries to KenLM")
|
||||
("kenlm-batch-threads", po::value<size_t>()->default_value(4),
|
||||
"Concurrent worker threads for batch processing")
|
||||
;
|
||||
|
||||
po::options_description cmdline_options("Allowed options");
|
||||
cmdline_options.add(general);
|
||||
cmdline_options.add(search);
|
||||
cmdline_options.add(kenlm);
|
||||
|
||||
try {
|
||||
po::store(po::command_line_parser(argc,argv)
|
||||
.options(cmdline_options).run(), vm_);
|
||||
po::notify(vm_);
|
||||
}
|
||||
catch (std::exception& e) {
|
||||
std::cerr << "Error: " << e.what() << std::endl << std::endl;
|
||||
|
||||
std::cerr << "Usage: " + std::string(argv[0]) + " [options]" << std::endl;
|
||||
std::cerr << cmdline_options << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (Get<bool>("help")) {
|
||||
std::cerr << "Usage: " + std::string(argv[0]) + " [options]" << std::endl;
|
||||
std::cerr << cmdline_options << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
PrintConfig();
|
||||
config_.AddOptions(argc, argv);
|
||||
config_.LogOptions();
|
||||
|
||||
for(auto& sourceVocabPath : sourceVocabPaths)
|
||||
for(auto sourceVocabPath : Get<std::vector<std::string>>("source"))
|
||||
sourceVocabs_.emplace_back(new Vocab(sourceVocabPath));
|
||||
targetVocab_.reset(new Vocab(targetVocabPath));
|
||||
|
||||
if(devices.empty()) {
|
||||
LOG(info) << "empty";
|
||||
devices.push_back(0);
|
||||
}
|
||||
targetVocab_.reset(new Vocab(Get<std::string>("target")));
|
||||
|
||||
auto modelPaths = Get<std::vector<std::string>>("model");
|
||||
|
||||
tabMap_ = Get<std::vector<size_t>>("tab-map");
|
||||
if(tabMap_.size() < modelPaths.size()) {
|
||||
// this should be a warning
|
||||
LOG(info) << "More neural models than weights, setting missing tabs to 0";
|
||||
LOG(info) << "More neural models than tabs, setting missing tabs to 0";
|
||||
tabMap_.resize(modelPaths.size(), 0);
|
||||
}
|
||||
|
||||
// @TODO: handle this better!
|
||||
weights_ = Get<std::vector<float>>("weights");
|
||||
if(weights_.size() < modelPaths.size()) {
|
||||
// this should be a warning
|
||||
LOG(info) << "More neural models than weights, setting weights to 1.0";
|
||||
@ -139,11 +53,11 @@ God& God::NonStaticInit(int argc, char** argv) {
|
||||
weights_.resize(modelPaths.size(), 1.0);
|
||||
}
|
||||
|
||||
if(weights_.size() < modelPaths.size() + lmPaths.size()) {
|
||||
// this should be a warning
|
||||
LOG(info) << "More KenLM models than weights, setting weights to 0.0";
|
||||
weights_.resize(weights_.size() + lmPaths.size(), 0.0);
|
||||
}
|
||||
//if(weights_.size() < modelPaths.size() + lmPaths.size()) {
|
||||
// // this should be a warning
|
||||
// LOG(info) << "More KenLM models than weights, setting weights to 0.0";
|
||||
// weights_.resize(weights_.size() + lmPaths.size(), 0.0);
|
||||
//}
|
||||
|
||||
if(Has("load-weights")) {
|
||||
LoadWeights(Get<std::string>("load-weights"));
|
||||
@ -157,6 +71,7 @@ God& God::NonStaticInit(int argc, char** argv) {
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto devices = Get<std::vector<size_t>>("devices");
|
||||
modelsPerDevice_.resize(devices.size());
|
||||
{
|
||||
ThreadPool devicePool(devices.size());
|
||||
@ -171,10 +86,10 @@ God& God::NonStaticInit(int argc, char** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
for(auto& lmPath : lmPaths) {
|
||||
LOG(info) << "Loading lm " << lmPath;
|
||||
lms_.emplace_back(lmPath, *targetVocab_);
|
||||
}
|
||||
//for(auto& lmPath : lmPaths) {
|
||||
// LOG(info) << "Loading lm " << lmPath;
|
||||
// lms_.emplace_back(lmPath, *targetVocab_);
|
||||
//}
|
||||
|
||||
return *this;
|
||||
}
|
||||
@ -230,34 +145,3 @@ void God::LoadWeights(const std::string& path) {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
void God::PrintConfig() {
|
||||
LOG(info) << "Options set: ";
|
||||
for(auto& entry: instance_.vm_) {
|
||||
std::stringstream ss;
|
||||
ss << "\t" << entry.first << " = ";
|
||||
try {
|
||||
for(auto& v : entry.second.as<std::vector<std::string>>())
|
||||
ss << v << " ";
|
||||
} catch(...) { }
|
||||
try {
|
||||
for(auto& v : entry.second.as<std::vector<float>>())
|
||||
ss << v << " ";
|
||||
} catch(...) { }
|
||||
try {
|
||||
for(auto& v : entry.second.as<std::vector<size_t>>())
|
||||
ss << v << " ";
|
||||
} catch(...) { }
|
||||
try {
|
||||
ss << entry.second.as<std::string>();
|
||||
} catch(...) { }
|
||||
try {
|
||||
ss << entry.second.as<bool>() ? "true" : "false";
|
||||
} catch(...) { }
|
||||
try {
|
||||
ss << entry.second.as<size_t>();
|
||||
} catch(...) { }
|
||||
|
||||
LOG(info) << ss.str();
|
||||
}
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "config.h"
|
||||
#include "types.h"
|
||||
#include "vocab.h"
|
||||
#include "scorer.h"
|
||||
@ -9,8 +9,6 @@
|
||||
|
||||
// this should not be here
|
||||
#include "kenlm.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
class Weights;
|
||||
|
||||
@ -25,12 +23,12 @@ class God {
|
||||
}
|
||||
|
||||
static bool Has(const std::string& key) {
|
||||
return instance_.vm_.count(key) > 0;
|
||||
return Summon().config_.Has(key);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static T Get(const std::string& key) {
|
||||
return instance_.vm_[key].as<T>();
|
||||
return Summon().config_.Get<T>(key);
|
||||
}
|
||||
|
||||
static Vocab& GetSourceVocab(size_t i = 0);
|
||||
@ -40,7 +38,6 @@ class God {
|
||||
static std::vector<size_t>& GetTabMap();
|
||||
|
||||
static void CleanUp();
|
||||
static void PrintConfig();
|
||||
|
||||
void LoadWeights(const std::string& path);
|
||||
|
||||
@ -48,7 +45,7 @@ class God {
|
||||
God& NonStaticInit(int argc, char** argv);
|
||||
|
||||
static God instance_;
|
||||
po::variables_map vm_;
|
||||
Config config_;
|
||||
|
||||
std::vector<std::unique_ptr<Vocab>> sourceVocabs_;
|
||||
std::unique_ptr<Vocab> targetVocab_;
|
||||
|
@ -15,7 +15,10 @@ class Encoder {
|
||||
|
||||
void Lookup(mblas::Matrix& Row, size_t i) {
|
||||
using namespace mblas;
|
||||
CopyRow(Row, w_.E_, i);
|
||||
if(i < w_.E_.Rows())
|
||||
CopyRow(Row, w_.E_, i);
|
||||
else
|
||||
CopyRow(Row, w_.E_, 1); // UNK
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -13,7 +13,7 @@ class SlowGRU {
|
||||
const mblas::Matrix& Context) const {
|
||||
using namespace mblas;
|
||||
|
||||
const size_t cols = State.Cols();
|
||||
const size_t cols = GetStateLength();
|
||||
|
||||
// @TODO: Optimization
|
||||
// @TODO: Launch streams to perform GEMMs in parallel
|
||||
@ -76,11 +76,11 @@ class FastGRU {
|
||||
public:
|
||||
FastGRU(const Weights& model)
|
||||
: w_(model) {
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
/*for(int i = 0; i < 4; ++i) {
|
||||
cudaStreamCreate(&s_[i]);
|
||||
cublasCreate(&h_[i]);
|
||||
cublasSetStream(h_[i], s_[i]);
|
||||
}
|
||||
}*/
|
||||
}
|
||||
|
||||
void GetNextState(mblas::Matrix& NextState,
|
||||
@ -88,7 +88,7 @@ class FastGRU {
|
||||
const mblas::Matrix& Context) const {
|
||||
using namespace mblas;
|
||||
|
||||
const size_t cols = State.Cols();
|
||||
const size_t cols = GetStateLength();
|
||||
|
||||
// @TODO: Optimization
|
||||
// @TODO: Launch streams to perform GEMMs in parallel
|
||||
|
@ -171,28 +171,53 @@ Matrix& Prod(cublasHandle_t handle, Matrix& C, const Matrix& A, const Matrix& B,
|
||||
Matrix::value_type alpha = 1.0;
|
||||
Matrix::value_type beta = 0.0;
|
||||
|
||||
//size_t m = A.Rows();
|
||||
//size_t k = A.Cols();
|
||||
////if(transA)
|
||||
//// std::swap(m, k);
|
||||
//
|
||||
//size_t l = B.Rows();
|
||||
//size_t n = B.Cols();
|
||||
////if(transB)
|
||||
//// std::swap(l, n);
|
||||
//
|
||||
//C.Resize(m, n);
|
||||
//
|
||||
//size_t lda = A.Cols();
|
||||
//size_t ldb = B.Cols();
|
||||
//size_t ldc = C.Cols();
|
||||
//
|
||||
//nervana_sgemm(const_cast<float*>(A.data()),
|
||||
// const_cast<float*>(B.data()),
|
||||
// C.data(),
|
||||
// transA, transB,
|
||||
// m, n, k,
|
||||
// lda, ldb, ldc,
|
||||
// alpha, beta,
|
||||
// 0, false, false, 0);
|
||||
|
||||
size_t m = A.Rows();
|
||||
size_t k = A.Cols();
|
||||
if(transA)
|
||||
std::swap(m, k);
|
||||
|
||||
|
||||
size_t l = B.Rows();
|
||||
size_t n = B.Cols();
|
||||
if(transB)
|
||||
std::swap(l, n);
|
||||
|
||||
|
||||
size_t lda = A.Cols();
|
||||
size_t ldb = B.Cols();
|
||||
size_t ldc = B.Cols();
|
||||
|
||||
|
||||
if(transB)
|
||||
ldc = B.Rows();
|
||||
|
||||
|
||||
C.Resize(m, n);
|
||||
|
||||
|
||||
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
|
||||
|
||||
cublasSgemm(handle, opB, opA,
|
||||
n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc);
|
||||
return C;
|
||||
|
@ -11,6 +11,9 @@
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/functional.h>
|
||||
|
||||
//#include "nervana_c_api.h"
|
||||
|
||||
|
||||
#include "thrust_functions.h"
|
||||
|
||||
namespace lib = thrust;
|
||||
|
Loading…
Reference in New Issue
Block a user