mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
Scorer model loading (#860)
* Add MMAP as an option * Use io::isBin * Allow getYamlFromModel from an Item vector * ScorerWrapper can now load on to a graph from Item vector The interface IEncoderDecoder can now call graph loads directly from an Item Vector. * Translator loads model before creating scorers Scorers are created from an Item vector * Replace model-config try-catch with check using IsNull * Prefer empty vs size * load by items should be pure virtual * Stepwise forward load to encdec * nematus can load from items * amun can load from items * loadItems in TranslateService * Remove logging * Remove by filename scorer functions * Replace by filename createScorer * Explicitly provide default value for get model-mmap * CLI option for model-mmap only for translation and CPU compile * Ensure model-mmap option is CPU only * Remove move on temporary object * Reinstate log messages for model loading in Amun / Nematus * Add log messages for model loading in scorers Co-authored-by: Roman Grundkiewicz <rgrundkiewicz@gmail.com>
This commit is contained in:
parent
c84599d08a
commit
b29cc07a95
@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
- Dynamic gradient-scaling with `--dynamic-gradient-scaling`.
|
||||
- Add unit tests for binary files.
|
||||
- Fix compilation with OMP
|
||||
- Added `--model-mmap` option to enable mmap loading for CPU-based translation
|
||||
- Compute aligned memory sizes using exact sizing
|
||||
- Support for loading lexical shortlist from a binary blob
|
||||
- Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option
|
||||
|
@ -183,7 +183,12 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
||||
"Path prefix for pre-trained model to initialize model weights");
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef COMPILE_CPU
|
||||
if(mode_ == cli::mode::translation) {
|
||||
cli.add<bool>("--model-mmap",
|
||||
"Use memory-mapping when loading model (CPU only)");
|
||||
}
|
||||
#endif
|
||||
cli.add<bool>("--ignore-model-config",
|
||||
"Ignore the model configuration saved in npz file");
|
||||
cli.add<std::string>("--type",
|
||||
|
@ -54,6 +54,9 @@ void ConfigValidator::validateOptionsTranslation() const {
|
||||
ABORT_IF(models.empty() && configs.empty(),
|
||||
"You need to provide at least one model file or a config file");
|
||||
|
||||
ABORT_IF(get<bool>("model-mmap") && get<size_t>("cpu-threads") == 0,
|
||||
"Model MMAP is CPU-only, please use --cpu-threads");
|
||||
|
||||
for(const auto& modelFile : models) {
|
||||
filesystem::Path modelPath(modelFile);
|
||||
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile);
|
||||
|
@ -56,6 +56,18 @@ void getYamlFromModel(YAML::Node& yaml,
|
||||
yaml = YAML::Load(item.data());
|
||||
}
|
||||
|
||||
// Load YAML from item
|
||||
void getYamlFromModel(YAML::Node& yaml,
|
||||
const std::string& varName,
|
||||
const std::vector<Item>& items) {
|
||||
for(auto& item : items) {
|
||||
if(item.name == varName) {
|
||||
yaml = YAML::Load(item.data());
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void addMetaToItems(const std::string& meta,
|
||||
const std::string& varName,
|
||||
std::vector<io::Item>& items) {
|
||||
|
@ -21,6 +21,7 @@ bool isBin(const std::string& fileName);
|
||||
|
||||
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::string& fileName);
|
||||
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const void* ptr);
|
||||
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::vector<Item>& items);
|
||||
|
||||
void addMetaToItems(const std::string& meta,
|
||||
const std::string& varName,
|
||||
|
@ -739,7 +739,7 @@ public:
|
||||
|
||||
public:
|
||||
/** Load model (mainly parameter objects) from array of io::Items */
|
||||
void load(std::vector<io::Item>& ioItems, bool markReloaded = true) {
|
||||
void load(const std::vector<io::Item>& ioItems, bool markReloaded = true) {
|
||||
setReloaded(false);
|
||||
for(auto& item : ioItems) {
|
||||
std::string pName = item.name;
|
||||
|
@ -36,7 +36,7 @@ public:
|
||||
}
|
||||
|
||||
void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
const std::vector<io::Item>& items,
|
||||
bool /*markedReloaded*/ = true) override {
|
||||
std::map<std::string, std::string> nameMap
|
||||
= {{"decoder_U", "decoder_cell1_U"},
|
||||
@ -89,9 +89,7 @@ public:
|
||||
if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
|
||||
nameMap["Wemb"] = "Wemb";
|
||||
|
||||
LOG(info, "Loading model from {}", name);
|
||||
// load items from .npz file
|
||||
auto ioItems = io::loadItems(name);
|
||||
auto ioItems = items;
|
||||
// map names and remove a dummy matrices
|
||||
for(auto it = ioItems.begin(); it != ioItems.end();) {
|
||||
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
|
||||
@ -120,6 +118,14 @@ public:
|
||||
graph->load(ioItems);
|
||||
}
|
||||
|
||||
void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool /*markReloaded*/ = true) override {
|
||||
LOG(info, "Loading model from {}", name);
|
||||
auto ioItems = io::loadItems(name);
|
||||
load(graph, ioItems);
|
||||
}
|
||||
|
||||
void save(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool saveTranslatorConfig = false) override {
|
||||
|
@ -325,6 +325,12 @@ protected:
|
||||
public:
|
||||
Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {}
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::vector<io::Item>& items,
|
||||
bool markedReloaded = true) override {
|
||||
encdec_->load(graph, items, markedReloaded);
|
||||
}
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool markedReloaded = true) override {
|
||||
|
@ -144,6 +144,12 @@ std::string EncoderDecoder::getModelParametersAsString() {
|
||||
return std::string(out.c_str());
|
||||
}
|
||||
|
||||
void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
|
||||
const std::vector<io::Item>& items,
|
||||
bool markedReloaded) {
|
||||
graph->load(items, markedReloaded && !opt<bool>("ignore-model-config", false));
|
||||
}
|
||||
|
||||
void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool markedReloaded) {
|
||||
|
@ -12,6 +12,11 @@ namespace marian {
|
||||
class IEncoderDecoder : public models::IModel {
|
||||
public:
|
||||
virtual ~IEncoderDecoder() {}
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::vector<io::Item>& items,
|
||||
bool markedReloaded = true) = 0;
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool markedReloaded = true) override
|
||||
@ -91,6 +96,10 @@ public:
|
||||
|
||||
void push_back(Ptr<DecoderBase> decoder);
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::vector<io::Item>& items,
|
||||
bool markedReloaded = true) override;
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool markedReloaded = true) override;
|
||||
|
@ -26,11 +26,9 @@ public:
|
||||
}
|
||||
|
||||
void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
const std::vector<io::Item>& items,
|
||||
bool /*markReloaded*/ = true) override {
|
||||
LOG(info, "Loading model from {}", name);
|
||||
// load items from .npz file
|
||||
auto ioItems = io::loadItems(name);
|
||||
auto ioItems = items;
|
||||
// map names and remove a dummy matrix 'decoder_c_tt' from items to avoid creating isolated node
|
||||
for(auto it = ioItems.begin(); it != ioItems.end();) {
|
||||
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
|
||||
@ -59,6 +57,14 @@ public:
|
||||
graph->load(ioItems);
|
||||
}
|
||||
|
||||
void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool /*markReloaded*/ = true) override {
|
||||
LOG(info, "Loading model from {}", name);
|
||||
auto ioItems = io::loadItems(name);
|
||||
load(graph, ioItems);
|
||||
}
|
||||
|
||||
void save(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool saveTranslatorConfig = false) override {
|
||||
|
@ -5,7 +5,7 @@ namespace marian {
|
||||
|
||||
Ptr<Scorer> scorerByType(const std::string& fname,
|
||||
float weight,
|
||||
const std::string& model,
|
||||
std::vector<io::Item> items,
|
||||
Ptr<Options> options) {
|
||||
options->set("inference", true);
|
||||
std::string type = options->get<std::string>("type");
|
||||
@ -22,7 +22,7 @@ Ptr<Scorer> scorerByType(const std::string& fname,
|
||||
|
||||
LOG(info, "Loading scorer of type {} as feature {}", type, fname);
|
||||
|
||||
return New<ScorerWrapper>(encdec, fname, weight, model);
|
||||
return New<ScorerWrapper>(encdec, fname, weight, items);
|
||||
}
|
||||
|
||||
Ptr<Scorer> scorerByType(const std::string& fname,
|
||||
@ -47,31 +47,31 @@ Ptr<Scorer> scorerByType(const std::string& fname,
|
||||
return New<ScorerWrapper>(encdec, fname, weight, ptr);
|
||||
}
|
||||
|
||||
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
|
||||
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models) {
|
||||
std::vector<Ptr<Scorer>> scorers;
|
||||
|
||||
auto models = options->get<std::vector<std::string>>("models");
|
||||
|
||||
std::vector<float> weights(models.size(), 1.f);
|
||||
if(options->hasAndNotEmpty("weights"))
|
||||
weights = options->get<std::vector<float>>("weights");
|
||||
|
||||
bool isPrevRightLeft = false; // if the previous model was a right-to-left model
|
||||
size_t i = 0;
|
||||
for(auto model : models) {
|
||||
for(auto items : models) {
|
||||
std::string fname = "F" + std::to_string(i);
|
||||
|
||||
// load options specific for the scorer
|
||||
auto modelOptions = New<Options>(options->clone());
|
||||
try {
|
||||
if(!options->get<bool>("ignore-model-config")) {
|
||||
YAML::Node modelYaml;
|
||||
io::getYamlFromModel(modelYaml, "special:model.yml", model);
|
||||
io::getYamlFromModel(modelYaml, "special:model.yml", items);
|
||||
if(!modelYaml.IsNull()) {
|
||||
LOG(info, "Loaded model config");
|
||||
modelOptions->merge(modelYaml, true);
|
||||
}
|
||||
} catch(std::runtime_error&) {
|
||||
else {
|
||||
LOG(warn, "No model settings found in model file");
|
||||
}
|
||||
}
|
||||
|
||||
// l2r and r2l cannot be used in the same ensemble
|
||||
if(models.size() > 1 && modelOptions->has("right-left")) {
|
||||
@ -85,13 +85,24 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
|
||||
}
|
||||
}
|
||||
|
||||
scorers.push_back(scorerByType(fname, weights[i], model, modelOptions));
|
||||
scorers.push_back(scorerByType(fname, weights[i], items, modelOptions));
|
||||
i++;
|
||||
}
|
||||
|
||||
return scorers;
|
||||
}
|
||||
|
||||
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
|
||||
std::vector<std::vector<io::Item>> model_items;
|
||||
auto models = options->get<std::vector<std::string>>("models");
|
||||
for(auto model : models) {
|
||||
auto items = io::loadItems(model);
|
||||
model_items.push_back(std::move(items));
|
||||
}
|
||||
|
||||
return createScorers(options, model_items);
|
||||
}
|
||||
|
||||
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<const void*>& ptrs) {
|
||||
std::vector<Ptr<Scorer>> scorers;
|
||||
|
||||
@ -105,15 +116,17 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<c
|
||||
|
||||
// load options specific for the scorer
|
||||
auto modelOptions = New<Options>(options->clone());
|
||||
try {
|
||||
if(!options->get<bool>("ignore-model-config")) {
|
||||
YAML::Node modelYaml;
|
||||
io::getYamlFromModel(modelYaml, "special:model.yml", ptr);
|
||||
if(!modelYaml.IsNull()) {
|
||||
LOG(info, "Loaded model config");
|
||||
modelOptions->merge(modelYaml, true);
|
||||
}
|
||||
} catch(std::runtime_error&) {
|
||||
else {
|
||||
LOG(warn, "No model settings found in model file");
|
||||
}
|
||||
}
|
||||
|
||||
scorers.push_back(scorerByType(fname, weights[i], ptr, modelOptions));
|
||||
i++;
|
||||
|
@ -73,9 +73,19 @@ class ScorerWrapper : public Scorer {
|
||||
private:
|
||||
Ptr<IEncoderDecoder> encdec_;
|
||||
std::string fname_;
|
||||
std::vector<io::Item> items_;
|
||||
const void* ptr_;
|
||||
|
||||
public:
|
||||
ScorerWrapper(Ptr<models::IModel> encdec,
|
||||
const std::string& name,
|
||||
float weight,
|
||||
std::vector<io::Item>& items)
|
||||
: Scorer(name, weight),
|
||||
encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
|
||||
items_(items),
|
||||
ptr_{0} {}
|
||||
|
||||
ScorerWrapper(Ptr<models::IModel> encdec,
|
||||
const std::string& name,
|
||||
float weight,
|
||||
@ -97,7 +107,9 @@ public:
|
||||
|
||||
virtual void init(Ptr<ExpressionGraph> graph) override {
|
||||
graph->switchParams(getName());
|
||||
if(ptr_)
|
||||
if(!items_.empty())
|
||||
encdec_->load(graph, items_);
|
||||
else if(ptr_)
|
||||
encdec_->mmap(graph, ptr_);
|
||||
else
|
||||
encdec_->load(graph, fname_);
|
||||
@ -142,12 +154,19 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
Ptr<Scorer> scorerByType(const std::string& fname,
|
||||
float weight,
|
||||
std::vector<io::Item> items,
|
||||
Ptr<Options> options);
|
||||
|
||||
Ptr<Scorer> scorerByType(const std::string& fname,
|
||||
float weight,
|
||||
const std::string& model,
|
||||
Ptr<Options> config);
|
||||
|
||||
|
||||
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options);
|
||||
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models);
|
||||
|
||||
Ptr<Scorer> scorerByType(const std::string& fname,
|
||||
float weight,
|
||||
|
@ -20,12 +20,7 @@
|
||||
#include "translator/scorers.h"
|
||||
|
||||
// currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled.
|
||||
// @TODO: add this as an actual feature.
|
||||
#define MMAP 0
|
||||
|
||||
#if MMAP
|
||||
#include "3rd_party/mio/mio.hpp"
|
||||
#endif
|
||||
|
||||
namespace marian {
|
||||
|
||||
@ -42,9 +37,8 @@ private:
|
||||
|
||||
size_t numDevices_;
|
||||
|
||||
#if MMAP
|
||||
std::vector<mio::mmap_source> mmaps_;
|
||||
#endif
|
||||
std::vector<mio::mmap_source> model_mmaps_; // map
|
||||
std::vector<std::vector<io::Item>> model_items_; // non-mmap
|
||||
|
||||
public:
|
||||
Translate(Ptr<Options> options)
|
||||
@ -76,15 +70,21 @@ public:
|
||||
scorers_.resize(numDevices_);
|
||||
graphs_.resize(numDevices_);
|
||||
|
||||
#if MMAP
|
||||
auto models = options->get<std::vector<std::string>>("models");
|
||||
if(options_->get<bool>("model-mmap", false)) {
|
||||
for(auto model : models) {
|
||||
marian::filesystem::Path modelPath(model);
|
||||
ABORT_IF(modelPath.extension() != marian::filesystem::Path(".bin"),
|
||||
"Non-binarized models cannot be mmapped");
|
||||
mmaps_.push_back(std::move(mio::mmap_source(model)));
|
||||
ABORT_IF(!io::isBin(model), "Non-binarized models cannot be mmapped");
|
||||
LOG(info, "Loading model from {}", model);
|
||||
model_mmaps_.push_back(mio::mmap_source(model));
|
||||
}
|
||||
}
|
||||
else {
|
||||
for(auto model : models) {
|
||||
LOG(info, "Loading model from {}", model);
|
||||
auto items = io::loadItems(model);
|
||||
model_items_.push_back(std::move(items));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
size_t id = 0;
|
||||
for(auto device : devices) {
|
||||
@ -101,11 +101,14 @@ public:
|
||||
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
|
||||
graphs_[id] = graph;
|
||||
|
||||
#if MMAP
|
||||
auto scorers = createScorers(options_, mmaps_);
|
||||
#else
|
||||
auto scorers = createScorers(options_);
|
||||
#endif
|
||||
std::vector<Ptr<Scorer>> scorers;
|
||||
if(options_->get<bool>("model-mmap", false)) {
|
||||
scorers = createScorers(options_, model_mmaps_);
|
||||
}
|
||||
else {
|
||||
scorers = createScorers(options_, model_items_);
|
||||
}
|
||||
|
||||
for(auto scorer : scorers) {
|
||||
scorer->init(graph);
|
||||
if(shortlistGenerator_)
|
||||
@ -288,6 +291,14 @@ public:
|
||||
auto devices = Config::getDevices(options_);
|
||||
numDevices_ = devices.size();
|
||||
|
||||
// preload models
|
||||
std::vector<std::vector<io::Item>> model_items_;
|
||||
auto models = options->get<std::vector<std::string>>("models");
|
||||
for(auto model : models) {
|
||||
auto items = io::loadItems(model);
|
||||
model_items_.push_back(std::move(items));
|
||||
}
|
||||
|
||||
// initialize scorers
|
||||
for(auto device : devices) {
|
||||
auto graph = New<ExpressionGraph>(true);
|
||||
@ -303,7 +314,7 @@ public:
|
||||
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
|
||||
graphs_.push_back(graph);
|
||||
|
||||
auto scorers = createScorers(options_);
|
||||
auto scorers = createScorers(options_, model_items_);
|
||||
for(auto scorer : scorers) {
|
||||
scorer->init(graph);
|
||||
if(shortlistGenerator_)
|
||||
|
Loading…
Reference in New Issue
Block a user