mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
yaml configuration for models, moved config out of god class, individual paramerters per model
This commit is contained in:
parent
7e76b61f88
commit
b80ec4fa04
@ -24,7 +24,6 @@ if (YAMLCPP_FOUND)
|
||||
set(EXT_LIBS ${EXT_LIBS} ${YAMLCPP_LIBRARY})
|
||||
endif (YAMLCPP_FOUND)
|
||||
|
||||
|
||||
set(KENLM CACHE STRING "Path to compiled kenlm directory")
|
||||
if (NOT EXISTS "${KENLM}/build/lib/libkenlm.a")
|
||||
message(FATAL_ERROR "Could not find ${KENLM}/build/lib/libkenlm.a")
|
||||
|
30
scripts/dropout.py
Executable file
30
scripts/dropout.py
Executable file
@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import os, sys
|
||||
import argparse
|
||||
import numpy as np;
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-d', '--dropout', type=float, required=True,
|
||||
help="dropout rate")
|
||||
parser.add_argument('-i', '--input', required=True,
|
||||
help="Input model")
|
||||
parser.add_argument('-o', '--output', required=True,
|
||||
help="Output model")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
multiplier = 1.0 - args.dropout
|
||||
|
||||
output = dict()
|
||||
print "Loading", args.input, "to multiple with", multiplier
|
||||
with open(args.input, "rb") as mfile:
|
||||
m = np.load(mfile)
|
||||
for k in m:
|
||||
if "history_errs" in k or "_b" in k or "c_tt" in k:
|
||||
output[k] = m[k]
|
||||
else:
|
||||
output[k] = multiplier * m[k]
|
||||
|
||||
print "Saving to", args.output
|
||||
np.savez(args.output, **output)
|
16
scripts/idf.py
Normal file
16
scripts/idf.py
Normal file
@ -0,0 +1,16 @@
|
||||
import sys
|
||||
import math
|
||||
from collections import Counter
|
||||
|
||||
c = Counter()
|
||||
N = 0
|
||||
for line in sys.stdin:
|
||||
uniq = set(line.split())
|
||||
for word in uniq:
|
||||
c[word] += 1
|
||||
N += 1
|
||||
|
||||
keys = sorted([k for k in c])
|
||||
for word in keys:
|
||||
idf = math.log(float(N) / float(c[word])) / math.log(N)
|
||||
print word, ":", idf
|
@ -8,6 +8,7 @@ include_directories(mblas)
|
||||
add_library(libcommon OBJECT
|
||||
cnpy/cnpy.cpp
|
||||
common/utils.cpp
|
||||
common/exception.cpp
|
||||
)
|
||||
|
||||
add_library(librescorer OBJECT
|
||||
|
104
src/common/exception.cpp
Normal file
104
src/common/exception.cpp
Normal file
@ -0,0 +1,104 @@
|
||||
#include "util/exception.hh"
|
||||
|
||||
#ifdef __GXX_RTTI
|
||||
#include <typeinfo>
|
||||
#endif
|
||||
|
||||
#include <cerrno>
|
||||
#include <cstring>
|
||||
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
#include <windows.h>
|
||||
#include <io.h>
|
||||
#endif
|
||||
|
||||
namespace util {
|
||||
|
||||
Exception::Exception() throw() {}
|
||||
Exception::~Exception() throw() {}
|
||||
|
||||
void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) {
|
||||
/* The child class might have set some text, but we want this to come first.
|
||||
* Another option would be passing this information to the constructor, but
|
||||
* then child classes would have to accept constructor arguments and pass
|
||||
* them down.
|
||||
*/
|
||||
std::string old_text;
|
||||
what_.swap(old_text);
|
||||
what_ << file << ':' << line;
|
||||
if (func) what_ << " in " << func << " threw ";
|
||||
if (child_name) {
|
||||
what_ << child_name;
|
||||
} else {
|
||||
#ifdef __GXX_RTTI
|
||||
what_ << typeid(this).name();
|
||||
#else
|
||||
what_ << "an exception";
|
||||
#endif
|
||||
}
|
||||
if (condition) {
|
||||
what_ << " because `" << condition << '\'';
|
||||
}
|
||||
what_ << ".\n";
|
||||
what_ << old_text;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
#ifdef __GNUC__
|
||||
const char *HandleStrerror(int ret, const char *buf) __attribute__ ((unused));
|
||||
const char *HandleStrerror(const char *ret, const char * /*buf*/) __attribute__ ((unused));
|
||||
#endif
|
||||
// At least one of these functions will not be called.
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#endif
|
||||
// The XOPEN version.
|
||||
const char *HandleStrerror(int ret, const char *buf) {
|
||||
if (!ret) return buf;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// The GNU version.
|
||||
const char *HandleStrerror(const char *ret, const char * /*buf*/) {
|
||||
return ret;
|
||||
}
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
ErrnoException::ErrnoException() throw() : errno_(errno) {
|
||||
char buf[200];
|
||||
buf[0] = 0;
|
||||
#if defined(sun) || defined(_WIN32) || defined(_WIN64)
|
||||
const char *add = strerror(errno);
|
||||
#else
|
||||
const char *add = HandleStrerror(strerror_r(errno, buf, 200), buf);
|
||||
#endif
|
||||
|
||||
if (add) {
|
||||
*this << add << ' ';
|
||||
}
|
||||
}
|
||||
|
||||
ErrnoException::~ErrnoException() throw() {}
|
||||
|
||||
OverflowException::OverflowException() throw() {}
|
||||
OverflowException::~OverflowException() throw() {}
|
||||
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
WindowsException::WindowsException() throw() {
|
||||
unsigned int last_error = GetLastError();
|
||||
char error_msg[256] = "";
|
||||
if (!FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, last_error, LANG_NEUTRAL, error_msg, sizeof(error_msg), NULL)) {
|
||||
*this << "Windows error " << GetLastError() << " while formatting Windows error " << last_error << ". ";
|
||||
} else {
|
||||
*this << "Windows error " << last_error << ": " << error_msg;
|
||||
}
|
||||
}
|
||||
WindowsException::~WindowsException() throw() {}
|
||||
#endif
|
||||
|
||||
} // namespace util
|
156
src/common/exception.h
Normal file
156
src/common/exception.h
Normal file
@ -0,0 +1,156 @@
|
||||
#pragma once
|
||||
|
||||
#include "util/string_stream.hh"
|
||||
|
||||
#include <exception>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace util {
|
||||
|
||||
template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data);
|
||||
|
||||
class Exception : public std::exception {
|
||||
public:
|
||||
Exception() throw();
|
||||
virtual ~Exception() throw();
|
||||
|
||||
const char *what() const throw() { return what_.str().c_str(); }
|
||||
|
||||
// For use by the UTIL_THROW macros.
|
||||
void SetLocation(
|
||||
const char *file,
|
||||
unsigned int line,
|
||||
const char *func,
|
||||
const char *child_name,
|
||||
const char *condition);
|
||||
|
||||
private:
|
||||
template <class Except, class Data> friend typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data);
|
||||
|
||||
// This helps restrict operator<< defined below.
|
||||
template <class T> struct ExceptionTag {
|
||||
typedef T Identity;
|
||||
};
|
||||
|
||||
StringStream what_;
|
||||
};
|
||||
|
||||
/* This implements the normal operator<< for Exception and all its children.
|
||||
* SFINAE means it only applies to Exception. Think of this as an ersatz
|
||||
* boost::enable_if.
|
||||
*/
|
||||
template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data) {
|
||||
e.what_ << data;
|
||||
return e;
|
||||
}
|
||||
|
||||
#ifdef __GNUC__
|
||||
#define UTIL_FUNC_NAME __PRETTY_FUNCTION__
|
||||
#else
|
||||
#ifdef _WIN32
|
||||
#define UTIL_FUNC_NAME __FUNCTION__
|
||||
#else
|
||||
#define UTIL_FUNC_NAME NULL
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/* Create an instance of Exception, add the message Modify, and throw it.
|
||||
* Modify is appended to the what() message and can contain << for ostream
|
||||
* operations.
|
||||
*
|
||||
* do .. while kludge to swallow trailing ; character
|
||||
* http://gcc.gnu.org/onlinedocs/cpp/Swallowing-the-Semicolon.html .
|
||||
* Arg can be a constructor argument to the exception.
|
||||
*/
|
||||
#define UTIL_THROW_BACKEND(Condition, Exception, Arg, Modify) do { \
|
||||
Exception UTIL_e Arg; \
|
||||
UTIL_e.SetLocation(__FILE__, __LINE__, UTIL_FUNC_NAME, #Exception, Condition); \
|
||||
UTIL_e << Modify; \
|
||||
throw UTIL_e; \
|
||||
} while (0)
|
||||
|
||||
#define UTIL_THROW_ARG(Exception, Arg, Modify) \
|
||||
UTIL_THROW_BACKEND(NULL, Exception, Arg, Modify)
|
||||
|
||||
#define UTIL_THROW(Exception, Modify) \
|
||||
UTIL_THROW_BACKEND(NULL, Exception, , Modify);
|
||||
|
||||
#define UTIL_THROW2(Modify) \
|
||||
UTIL_THROW_BACKEND(NULL, util::Exception, , Modify);
|
||||
|
||||
#if __GNUC__ >= 3
|
||||
#define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0)
|
||||
#else
|
||||
#define UTIL_UNLIKELY(x) (x)
|
||||
#endif
|
||||
|
||||
#if __GNUC__ >= 3
|
||||
#define UTIL_LIKELY(x) __builtin_expect (!!(x), 1)
|
||||
#else
|
||||
#define UTIL_LIKELY(x) (x)
|
||||
#endif
|
||||
|
||||
#define UTIL_THROW_IF_ARG(Condition, Exception, Arg, Modify) do { \
|
||||
if (UTIL_UNLIKELY(Condition)) { \
|
||||
UTIL_THROW_BACKEND(#Condition, Exception, Arg, Modify); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define UTIL_THROW_IF(Condition, Exception, Modify) \
|
||||
UTIL_THROW_IF_ARG(Condition, Exception, , Modify)
|
||||
|
||||
#define UTIL_THROW_IF2(Condition, Modify) \
|
||||
UTIL_THROW_IF_ARG(Condition, util::Exception, , Modify)
|
||||
|
||||
// Exception that records errno and adds it to the message.
|
||||
class ErrnoException : public Exception {
|
||||
public:
|
||||
ErrnoException() throw();
|
||||
|
||||
virtual ~ErrnoException() throw();
|
||||
|
||||
int Error() const throw() { return errno_; }
|
||||
|
||||
private:
|
||||
int errno_;
|
||||
};
|
||||
|
||||
// file wasn't there, or couldn't be open for some reason
|
||||
class FileOpenException : public Exception {
|
||||
public:
|
||||
FileOpenException() throw() {}
|
||||
~FileOpenException() throw() {}
|
||||
};
|
||||
|
||||
// Utilities for overflow checking.
|
||||
class OverflowException : public Exception {
|
||||
public:
|
||||
OverflowException() throw();
|
||||
~OverflowException() throw();
|
||||
};
|
||||
|
||||
template <unsigned len> inline std::size_t CheckOverflowInternal(uint64_t value) {
|
||||
UTIL_THROW_IF(value > static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), OverflowException, "Integer overflow detected. This model is too big for 32-bit code.");
|
||||
return value;
|
||||
}
|
||||
|
||||
template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) {
|
||||
return value;
|
||||
}
|
||||
|
||||
inline std::size_t CheckOverflow(uint64_t value) {
|
||||
return CheckOverflowInternal<sizeof(std::size_t)>(value);
|
||||
}
|
||||
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
/* Thrown for Windows specific operations. */
|
||||
class WindowsException : public Exception {
|
||||
public:
|
||||
WindowsException() throw();
|
||||
~WindowsException() throw();
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace util
|
@ -7,7 +7,7 @@
|
||||
#include "matrix.h"
|
||||
|
||||
class ApePenaltyState : public State {
|
||||
// Dummy
|
||||
// Dummy, this scorer is stateless
|
||||
};
|
||||
|
||||
class ApePenalty : public Scorer {
|
||||
@ -17,6 +17,7 @@ class ApePenalty : public Scorer {
|
||||
: Scorer(sourceIndex)
|
||||
{ }
|
||||
|
||||
// @TODO: make this work on GPU
|
||||
virtual void SetSource(const Sentence& source) {
|
||||
const Words& words = source.GetWords(sourceIndex_);
|
||||
const Vocab& svcb = God::GetSourceVocab(sourceIndex_);
|
||||
@ -57,3 +58,18 @@ class ApePenalty : public Scorer {
|
||||
private:
|
||||
std::vector<float> costs_;
|
||||
};
|
||||
|
||||
class ApePenaltyLoader : public Loader {
|
||||
public:
|
||||
ApePenaltyLoader(const YAML::Node& config)
|
||||
: Loader(config) {}
|
||||
|
||||
virtual void Load() {
|
||||
// @TODO: IDF weights
|
||||
}
|
||||
|
||||
virtual ScorerPtr NewScorer(size_t taskId) {
|
||||
size_t tab = Has("tab") ? Get<size_t>("tab") : 0;
|
||||
return ScorerPtr(new ApePenalty(tab));
|
||||
}
|
||||
};
|
@ -1,6 +1,7 @@
|
||||
#include <set>
|
||||
|
||||
#include "config.h"
|
||||
#include "exception.h"
|
||||
|
||||
#define SET_OPTION(key, type) \
|
||||
if(!vm_[key].defaulted() || !config_[key]) { \
|
||||
@ -26,35 +27,20 @@ void Config::AddOptions(size_t argc, char** argv) {
|
||||
|
||||
std::string configPath;
|
||||
std::vector<size_t> devices;
|
||||
std::vector<size_t> tabMap;
|
||||
std::vector<float> weights;
|
||||
|
||||
std::vector<std::string> modelPaths;
|
||||
std::vector<std::string> lmPaths;
|
||||
std::vector<std::string> sourceVocabPaths;
|
||||
std::string targetVocabPath;
|
||||
|
||||
general.add_options()
|
||||
("config,c", po::value(&configPath),
|
||||
("config,c", po::value(&configPath)->required(),
|
||||
"Configuration file")
|
||||
("model,m", po::value(&modelPaths)->multitoken()->required(),
|
||||
"Path to neural translation model(s)")
|
||||
("source,s", po::value(&sourceVocabPaths)->multitoken()->required(),
|
||||
"Path to source vocabulary file.")
|
||||
("target,t", po::value(&targetVocabPath)->required(),
|
||||
"Path to target vocabulary file.")
|
||||
("ape", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Add APE-penalty")
|
||||
("lm,l", po::value(&lmPaths)->multitoken(),
|
||||
"Path to KenLM language model(s)")
|
||||
("tab-map", po::value(&tabMap)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
|
||||
"tab map")
|
||||
("devices,d", po::value(&devices)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
|
||||
"CUDA device(s) to use, set to 0 by default, "
|
||||
"e.g. set to 0 1 to use gpu0 and gpu1. "
|
||||
"Implicitly sets minimal number of threads to number of devices.")
|
||||
("threads-per-device", po::value<size_t>()->default_value(1),
|
||||
"Number of threads per device, total thread count equals threads x devices")
|
||||
("show-weights", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Output used weights to stdout and exit")
|
||||
("load-weights", po::value<std::string>(),
|
||||
"Load scorer weights from this file")
|
||||
("help,h", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Print this help message and exit")
|
||||
;
|
||||
@ -67,26 +53,11 @@ void Config::AddOptions(size_t argc, char** argv) {
|
||||
"Normalize scores by translation length after decoding")
|
||||
("n-best", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Output n-best list with n = beam-size")
|
||||
("weights,w", po::value(&weights)->multitoken()->default_value(std::vector<float>(1, 1.0), "1.0"),
|
||||
"Model weights (for neural models and KenLM models)")
|
||||
("show-weights", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Output used weights to stdout and exit")
|
||||
("load-weights", po::value<std::string>(),
|
||||
"Load scorer weights from this file")
|
||||
;
|
||||
|
||||
po::options_description kenlm("KenLM specific options");
|
||||
kenlm.add_options()
|
||||
("kenlm-batch-size", po::value<size_t>()->default_value(1000),
|
||||
"Batch size for batched queries to KenLM")
|
||||
("kenlm-batch-threads", po::value<size_t>()->default_value(4),
|
||||
"Concurrent worker threads for batch processing")
|
||||
;
|
||||
|
||||
po::options_description cmdline_options("Allowed options");
|
||||
cmdline_options.add(general);
|
||||
cmdline_options.add(search);
|
||||
cmdline_options.add(kenlm);
|
||||
|
||||
po::variables_map vm_;
|
||||
try {
|
||||
@ -108,32 +79,35 @@ void Config::AddOptions(size_t argc, char** argv) {
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(configPath.size())
|
||||
config_ = YAML::LoadFile(configPath);
|
||||
|
||||
SET_OPTION("model", std::vector<std::string>)
|
||||
SET_OPTION_NONDEFAULT("lm", std::vector<std::string>)
|
||||
SET_OPTION("ape", bool)
|
||||
SET_OPTION("source", std::vector<std::string>)
|
||||
SET_OPTION("target", std::string)
|
||||
|
||||
SET_OPTION("n-best", bool)
|
||||
SET_OPTION("normalize", bool)
|
||||
SET_OPTION("beam-size", size_t)
|
||||
SET_OPTION("threads-per-device", size_t)
|
||||
SET_OPTION("devices", std::vector<size_t>)
|
||||
SET_OPTION("tab-map", std::vector<size_t>)
|
||||
|
||||
SET_OPTION("weights", std::vector<float>)
|
||||
SET_OPTION("show-weights", bool)
|
||||
SET_OPTION_NONDEFAULT("load-weights", std::string)
|
||||
|
||||
SET_OPTION("kenlm-batch-size", size_t)
|
||||
SET_OPTION("kenlm-batch-threads", size_t)
|
||||
Validate();
|
||||
}
|
||||
|
||||
void Config::Validate() {
|
||||
UTIL_THROW_IF2(!config_["scorers"] || config_["scorers"].size() == 0,
|
||||
"No scorers given in config file");
|
||||
|
||||
UTIL_THROW_IF2(!config_["source-vocab"] || config_["source-vocab"].size() == 0,
|
||||
"No source-vocab given in config file");
|
||||
|
||||
UTIL_THROW_IF2(!config_["target-vocab"],
|
||||
"No target-vocab given in config file");
|
||||
|
||||
UTIL_THROW_IF2(config_["weights"].size() != config_["scorers"].size(),
|
||||
"Different number of models and weights in config file");
|
||||
}
|
||||
|
||||
void OutputRec(const YAML::Node node, YAML::Emitter& out) {
|
||||
std::set<std::string> flow = { "weights", "devices", "tab-map" };
|
||||
std::set<std::string> flow = { "weights", "devices"};
|
||||
std::set<std::string> sorter;
|
||||
switch (node.Type()) {
|
||||
case YAML::NodeType::Null:
|
||||
|
@ -27,5 +27,6 @@ class Config {
|
||||
return out;
|
||||
}
|
||||
|
||||
void Validate();
|
||||
void LogOptions();
|
||||
};
|
||||
|
@ -1,11 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <yaml-cpp/yaml.h>
|
||||
|
||||
#include "matrix.h"
|
||||
#include "scorer.h"
|
||||
#include "loader.h"
|
||||
#include "dl4mt.h"
|
||||
|
||||
#include "threadpool.h"
|
||||
|
||||
class EncoderDecoderState : public State {
|
||||
public:
|
||||
mblas::Matrix& GetStates() {
|
||||
@ -34,8 +38,8 @@ class EncoderDecoder : public Scorer {
|
||||
typedef EncoderDecoderState EDState;
|
||||
|
||||
public:
|
||||
EncoderDecoder(const Weights& model, size_t sourceIndex)
|
||||
: Scorer(sourceIndex), model_(model),
|
||||
EncoderDecoder(const Weights& model, size_t tabIndex)
|
||||
: Scorer(tabIndex), model_(model),
|
||||
encoder_(new Encoder(model_)), decoder_(new Decoder(model_))
|
||||
{}
|
||||
|
||||
@ -98,3 +102,33 @@ class EncoderDecoder : public Scorer {
|
||||
|
||||
mblas::Matrix SourceContext_;
|
||||
};
|
||||
|
||||
class EncoderDecoderLoader : public Loader {
|
||||
public:
|
||||
EncoderDecoderLoader(const YAML::Node& config)
|
||||
: Loader(config) {}
|
||||
|
||||
virtual void Load() {
|
||||
std::string path = Get<std::string>("path");
|
||||
auto devices = God::Get<std::vector<size_t>>("devices");
|
||||
ThreadPool devicePool(devices.size());
|
||||
for(auto d : devices) {
|
||||
devicePool.enqueue([d, &path, this] {
|
||||
LOG(info) << "Loading model " << path << " onto gpu" << d;
|
||||
cudaSetDevice(d);
|
||||
weights_.emplace_back(new Weights(path, d));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
virtual ScorerPtr NewScorer(size_t taskId) {
|
||||
size_t i = taskId % weights_.size();
|
||||
size_t d = weights_[i]->GetDevice();
|
||||
cudaSetDevice(d);
|
||||
size_t tab = Has("tab") ? Get<size_t>("tab") : 0;
|
||||
return ScorerPtr(new EncoderDecoder(*weights_[i], tab));
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<Weights>> weights_;
|
||||
};
|
||||
|
@ -7,9 +7,7 @@
|
||||
#include "config.h"
|
||||
#include "scorer.h"
|
||||
#include "threadpool.h"
|
||||
#include "encoder_decoder.h"
|
||||
#include "language_model.h"
|
||||
#include "ape_penalty.h"
|
||||
#include "loader_factory.h"
|
||||
|
||||
God God::instance_;
|
||||
|
||||
@ -27,37 +25,11 @@ God& God::NonStaticInit(int argc, char** argv) {
|
||||
config_.AddOptions(argc, argv);
|
||||
config_.LogOptions();
|
||||
|
||||
for(auto sourceVocabPath : Get<std::vector<std::string>>("source"))
|
||||
for(auto sourceVocabPath : Get<std::vector<std::string>>("source-vocab"))
|
||||
sourceVocabs_.emplace_back(new Vocab(sourceVocabPath));
|
||||
targetVocab_.reset(new Vocab(Get<std::string>("target")));
|
||||
targetVocab_.reset(new Vocab(Get<std::string>("target-vocab")));
|
||||
|
||||
auto modelPaths = Get<std::vector<std::string>>("model");
|
||||
|
||||
tabMap_ = Get<std::vector<size_t>>("tab-map");
|
||||
if(tabMap_.size() < modelPaths.size()) {
|
||||
// this should be a warning
|
||||
LOG(info) << "More neural models than tabs, setting missing tabs to 0";
|
||||
tabMap_.resize(modelPaths.size(), 0);
|
||||
}
|
||||
|
||||
// @TODO: handle this better!
|
||||
weights_ = Get<std::vector<float>>("weights");
|
||||
if(weights_.size() < modelPaths.size()) {
|
||||
// this should be a warning
|
||||
LOG(info) << "More neural models than weights, setting weights to 1.0";
|
||||
weights_.resize(modelPaths.size(), 1.0);
|
||||
}
|
||||
|
||||
if(Get<bool>("ape") && weights_.size() < modelPaths.size() + 1) {
|
||||
LOG(info) << "Adding weight for APE-penalty: " << 1.0;
|
||||
weights_.resize(modelPaths.size(), 1.0);
|
||||
}
|
||||
|
||||
//if(weights_.size() < modelPaths.size() + lmPaths.size()) {
|
||||
// // this should be a warning
|
||||
// LOG(info) << "More KenLM models than weights, setting weights to 0.0";
|
||||
// weights_.resize(weights_.size() + lmPaths.size(), 0.0);
|
||||
//}
|
||||
|
||||
if(Has("load-weights")) {
|
||||
LoadWeights(Get<std::string>("load-weights"));
|
||||
@ -71,25 +43,8 @@ God& God::NonStaticInit(int argc, char** argv) {
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto devices = Get<std::vector<size_t>>("devices");
|
||||
modelsPerDevice_.resize(devices.size());
|
||||
{
|
||||
ThreadPool devicePool(devices.size());
|
||||
for(auto& modelPath : modelPaths) {
|
||||
for(size_t i = 0; i < devices.size(); ++i) {
|
||||
devicePool.enqueue([i, &devices, &modelPath, this]{
|
||||
LOG(info) << "Loading model " << modelPath << " onto gpu" << devices[i];
|
||||
cudaSetDevice(devices[i]);
|
||||
modelsPerDevice_[i].emplace_back(new Weights(modelPath, devices[i]));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//for(auto& lmPath : lmPaths) {
|
||||
// LOG(info) << "Loading lm " << lmPath;
|
||||
// lms_.emplace_back(lmPath, *targetVocab_);
|
||||
//}
|
||||
for(auto&& modelConfig : config_.Get()["scorers"])
|
||||
loaders_.emplace_back(LoaderFactory::Create(modelConfig));
|
||||
|
||||
return *this;
|
||||
}
|
||||
@ -102,18 +57,10 @@ Vocab& God::GetTargetVocab() {
|
||||
return *Summon().targetVocab_;
|
||||
}
|
||||
|
||||
std::vector<ScorerPtr> God::GetScorers(size_t threadId) {
|
||||
size_t deviceId = threadId % Summon().modelsPerDevice_.size();
|
||||
size_t device = Summon().modelsPerDevice_[deviceId][0]->GetDevice();
|
||||
cudaSetDevice(device);
|
||||
std::vector<ScorerPtr> God::GetScorers(size_t taskId) {
|
||||
std::vector<ScorerPtr> scorers;
|
||||
size_t i = 0;
|
||||
for(auto& m : Summon().modelsPerDevice_[deviceId])
|
||||
scorers.emplace_back(new EncoderDecoder(*m, Summon().tabMap_[i++]));
|
||||
if(God::Get<bool>("ape"))
|
||||
scorers.emplace_back(new ApePenalty(Summon().tabMap_[i++]));
|
||||
for(auto& lm : Summon().lms_)
|
||||
scorers.emplace_back(new LanguageModel(lm));
|
||||
for(auto&& loader : Summon().loaders_)
|
||||
scorers.emplace_back(loader->NewScorer(taskId));
|
||||
return scorers;
|
||||
}
|
||||
|
||||
@ -121,15 +68,10 @@ std::vector<float>& God::GetScorerWeights() {
|
||||
return Summon().weights_;
|
||||
}
|
||||
|
||||
std::vector<size_t>& God::GetTabMap() {
|
||||
return Summon().tabMap_;
|
||||
}
|
||||
|
||||
// clean up cuda vectors before cuda context goes out of scope
|
||||
void God::CleanUp() {
|
||||
for(auto& models : Summon().modelsPerDevice_)
|
||||
for(auto& m : models)
|
||||
m.reset(nullptr);
|
||||
for(auto& loader : Summon().loaders_)
|
||||
loader.reset(nullptr);
|
||||
}
|
||||
|
||||
void God::LoadWeights(const std::string& path) {
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include "config.h"
|
||||
#include "types.h"
|
||||
#include "vocab.h"
|
||||
#include "loader.h"
|
||||
#include "scorer.h"
|
||||
#include "logging.h"
|
||||
|
||||
@ -50,16 +51,9 @@ class God {
|
||||
std::vector<std::unique_ptr<Vocab>> sourceVocabs_;
|
||||
std::unique_ptr<Vocab> targetVocab_;
|
||||
|
||||
typedef std::unique_ptr<Weights> Model;
|
||||
typedef std::vector<Model> Models;
|
||||
typedef std::vector<Models> ModelsPerDevice;
|
||||
std::vector<LoaderPtr> loaders_;
|
||||
|
||||
ModelsPerDevice modelsPerDevice_;
|
||||
std::vector<LM> lms_;
|
||||
|
||||
std::vector<ScorerPtr> scorers_;
|
||||
std::vector<float> weights_;
|
||||
std::vector<size_t> tabMap_;
|
||||
|
||||
std::shared_ptr<spdlog::logger> info_;
|
||||
std::shared_ptr<spdlog::logger> progress_;
|
||||
|
32
src/decoder/loader.h
Normal file
32
src/decoder/loader.h
Normal file
@ -0,0 +1,32 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <yaml-cpp/yaml.h>
|
||||
|
||||
#include "scorer.h"
|
||||
|
||||
class Loader {
|
||||
public:
|
||||
Loader(const YAML::Node& config)
|
||||
: config_(YAML::Clone(config)) {}
|
||||
|
||||
virtual ~Loader() {};
|
||||
|
||||
virtual void Load() = 0;
|
||||
|
||||
bool Has(const std::string& key) {
|
||||
return config_[key];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T Get(const std::string& key) {
|
||||
return config_[key].as<T>();
|
||||
}
|
||||
|
||||
virtual ScorerPtr NewScorer(size_t) = 0;
|
||||
|
||||
protected:
|
||||
const YAML::Node config_;
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<Loader> LoaderPtr;
|
46
src/decoder/loader_factory.h
Normal file
46
src/decoder/loader_factory.h
Normal file
@ -0,0 +1,46 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <yaml-cpp/yaml.h>
|
||||
|
||||
#include "exception.h"
|
||||
#include "scorer.h"
|
||||
#include "encoder_decoder.h"
|
||||
#include "ape_penalty.h"
|
||||
|
||||
#ifdef KENLM
|
||||
#include "language_model.h"
|
||||
#endif
|
||||
|
||||
#define IF_MATCH_RETURN(typeStr, nameStr, LoaderType) \
|
||||
do { \
|
||||
if(typeStr == nameStr) { \
|
||||
LoaderPtr loader(new LoaderType(config)); \
|
||||
loader->Load(); \
|
||||
return loader; \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
class LoaderFactory {
|
||||
public:
|
||||
static LoaderPtr Create(const YAML::Node& config) {
|
||||
UTIL_THROW_IF2(!config["type"],
|
||||
"Missing scorer type in config file");
|
||||
|
||||
auto type = config["type"].as<std::string>();
|
||||
IF_MATCH_RETURN(type, "Nematus", EncoderDecoderLoader);
|
||||
IF_MATCH_RETURN(type, "nematus", EncoderDecoderLoader);
|
||||
IF_MATCH_RETURN(type, "NEMATUS", EncoderDecoderLoader);
|
||||
|
||||
IF_MATCH_RETURN(type, "Ape", ApePenaltyLoader);
|
||||
IF_MATCH_RETURN(type, "ape", ApePenaltyLoader);
|
||||
IF_MATCH_RETURN(type, "APE", ApePenaltyLoader);
|
||||
#ifdef KENLM
|
||||
IF_MATCH_RETURN(type, "KenLM", KenLMLoader)
|
||||
IF_MATCH_RETURN(type, "kenlm", KenLMLoader)
|
||||
IF_MATCH_RETURN(type, "KENLM", KenLMLoader)
|
||||
#endif
|
||||
UTIL_THROW2("Unknown scorer in config file: " << type);
|
||||
}
|
||||
};
|
||||
|
@ -5,13 +5,11 @@
|
||||
Sentence::Sentence(size_t lineNo, const std::string& line)
|
||||
: lineNo_(lineNo), line_(line)
|
||||
{
|
||||
auto& tabMap = God::GetTabMap();
|
||||
|
||||
std::vector<std::string> tabs;
|
||||
Split(line, tabs, "\t");
|
||||
size_t i = 0;
|
||||
for(auto&& tab : tabs)
|
||||
words_.push_back(God::GetSourceVocab(tabMap[i++])(tab));
|
||||
words_.push_back(God::GetSourceVocab(i++)(tab));
|
||||
}
|
||||
|
||||
const Words& Sentence::GetWords(size_t index) const {
|
||||
|
Loading…
Reference in New Issue
Block a user