Scorer model loading (#860)

* Add MMAP as an option
* Use io::isBin
* Allow getYamlFromModel from an Item vector
* ScorerWrapper can now load on to a graph from Item vector
The interface IEncoderDecoder can now call graph loads directly from an
Item Vector.
* Translator loads model before creating scorers
Scorers are created from an Item vector
* Replace model-config try-catch with check using IsNull
* Prefer empty vs size
* load by items should be pure virtual
* Stepwise forward load to encdec
* nematus can load from items
* amun can load from items
* loadItems in TranslateService
* Remove logging
* Remove by filename scorer functions
* Replace by filename createScorer
* Explicitly provide default value for get model-mmap
* CLI option for model-mmap only for translation and CPU compile
* Ensure model-mmap option is CPU only
* Remove move on temporary object
* Reinstate log messages for model loading in Amun / Nematus
* Add log messages for model loading in scorers

Co-authored-by: Roman Grundkiewicz <rgrundkiewicz@gmail.com>
This commit is contained in:
Graeme Nail 2022-01-18 12:58:52 +00:00 committed by GitHub
parent c84599d08a
commit b29cc07a95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 162 additions and 64 deletions

View File

@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Dynamic gradient-scaling with `--dynamic-gradient-scaling`.
- Add unit tests for binary files.
- Fix compilation with OMP
- Added `--model-mmap` option to enable mmap loading for CPU-based translation
- Compute aligned memory sizes using exact sizing
- Support for loading lexical shortlist from a binary blob
- Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option

View File

@ -183,7 +183,12 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"Path prefix for pre-trained model to initialize model weights");
}
}
#ifdef COMPILE_CPU
if(mode_ == cli::mode::translation) {
cli.add<bool>("--model-mmap",
"Use memory-mapping when loading model (CPU only)");
}
#endif
cli.add<bool>("--ignore-model-config",
"Ignore the model configuration saved in npz file");
cli.add<std::string>("--type",

View File

@ -54,6 +54,9 @@ void ConfigValidator::validateOptionsTranslation() const {
ABORT_IF(models.empty() && configs.empty(),
"You need to provide at least one model file or a config file");
ABORT_IF(get<bool>("model-mmap") && get<size_t>("cpu-threads") == 0,
"Model MMAP is CPU-only, please use --cpu-threads");
for(const auto& modelFile : models) {
filesystem::Path modelPath(modelFile);
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile);

View File

@ -56,6 +56,18 @@ void getYamlFromModel(YAML::Node& yaml,
yaml = YAML::Load(item.data());
}
// Load YAML from item
void getYamlFromModel(YAML::Node& yaml,
const std::string& varName,
const std::vector<Item>& items) {
for(auto& item : items) {
if(item.name == varName) {
yaml = YAML::Load(item.data());
return;
}
}
}
void addMetaToItems(const std::string& meta,
const std::string& varName,
std::vector<io::Item>& items) {

View File

@ -21,6 +21,7 @@ bool isBin(const std::string& fileName);
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::string& fileName);
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const void* ptr);
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::vector<Item>& items);
void addMetaToItems(const std::string& meta,
const std::string& varName,

View File

@ -739,7 +739,7 @@ public:
public:
/** Load model (mainly parameter objects) from array of io::Items */
void load(std::vector<io::Item>& ioItems, bool markReloaded = true) {
void load(const std::vector<io::Item>& ioItems, bool markReloaded = true) {
setReloaded(false);
for(auto& item : ioItems) {
std::string pName = item.name;

View File

@ -36,7 +36,7 @@ public:
}
void load(Ptr<ExpressionGraph> graph,
const std::string& name,
const std::vector<io::Item>& items,
bool /*markedReloaded*/ = true) override {
std::map<std::string, std::string> nameMap
= {{"decoder_U", "decoder_cell1_U"},
@ -89,9 +89,7 @@ public:
if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
nameMap["Wemb"] = "Wemb";
LOG(info, "Loading model from {}", name);
// load items from .npz file
auto ioItems = io::loadItems(name);
auto ioItems = items;
// map names and remove a dummy matrices
for(auto it = ioItems.begin(); it != ioItems.end();) {
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
@ -120,6 +118,14 @@ public:
graph->load(ioItems);
}
void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool /*markReloaded*/ = true) override {
LOG(info, "Loading model from {}", name);
auto ioItems = io::loadItems(name);
load(graph, ioItems);
}
void save(Ptr<ExpressionGraph> graph,
const std::string& name,
bool saveTranslatorConfig = false) override {

View File

@ -325,6 +325,12 @@ protected:
public:
Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {}
virtual void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded = true) override {
encdec_->load(graph, items, markedReloaded);
}
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override {

View File

@ -144,6 +144,12 @@ std::string EncoderDecoder::getModelParametersAsString() {
return std::string(out.c_str());
}
void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded) {
graph->load(items, markedReloaded && !opt<bool>("ignore-model-config", false));
}
void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded) {

View File

@ -12,6 +12,11 @@ namespace marian {
class IEncoderDecoder : public models::IModel {
public:
virtual ~IEncoderDecoder() {}
virtual void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded = true) = 0;
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override
@ -91,6 +96,10 @@ public:
void push_back(Ptr<DecoderBase> decoder);
virtual void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded = true) override;
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override;

View File

@ -26,11 +26,9 @@ public:
}
void load(Ptr<ExpressionGraph> graph,
const std::string& name,
const std::vector<io::Item>& items,
bool /*markReloaded*/ = true) override {
LOG(info, "Loading model from {}", name);
// load items from .npz file
auto ioItems = io::loadItems(name);
auto ioItems = items;
// map names and remove a dummy matrix 'decoder_c_tt' from items to avoid creating isolated node
for(auto it = ioItems.begin(); it != ioItems.end();) {
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
@ -41,7 +39,7 @@ public:
it->shape.set(0, 1);
it->shape.set(1, dim);
}
if(it->name == "decoder_c_tt") {
it = ioItems.erase(it);
} else if(it->name == "uidx") {
@ -59,6 +57,14 @@ public:
graph->load(ioItems);
}
void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool /*markReloaded*/ = true) override {
LOG(info, "Loading model from {}", name);
auto ioItems = io::loadItems(name);
load(graph, ioItems);
}
void save(Ptr<ExpressionGraph> graph,
const std::string& name,
bool saveTranslatorConfig = false) override {

View File

@ -5,7 +5,7 @@ namespace marian {
Ptr<Scorer> scorerByType(const std::string& fname,
float weight,
const std::string& model,
std::vector<io::Item> items,
Ptr<Options> options) {
options->set("inference", true);
std::string type = options->get<std::string>("type");
@ -22,7 +22,7 @@ Ptr<Scorer> scorerByType(const std::string& fname,
LOG(info, "Loading scorer of type {} as feature {}", type, fname);
return New<ScorerWrapper>(encdec, fname, weight, model);
return New<ScorerWrapper>(encdec, fname, weight, items);
}
Ptr<Scorer> scorerByType(const std::string& fname,
@ -47,30 +47,30 @@ Ptr<Scorer> scorerByType(const std::string& fname,
return New<ScorerWrapper>(encdec, fname, weight, ptr);
}
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models) {
std::vector<Ptr<Scorer>> scorers;
auto models = options->get<std::vector<std::string>>("models");
std::vector<float> weights(models.size(), 1.f);
if(options->hasAndNotEmpty("weights"))
weights = options->get<std::vector<float>>("weights");
bool isPrevRightLeft = false; // if the previous model was a right-to-left model
size_t i = 0;
for(auto model : models) {
for(auto items : models) {
std::string fname = "F" + std::to_string(i);
// load options specific for the scorer
auto modelOptions = New<Options>(options->clone());
try {
if(!options->get<bool>("ignore-model-config")) {
YAML::Node modelYaml;
io::getYamlFromModel(modelYaml, "special:model.yml", model);
if(!options->get<bool>("ignore-model-config")) {
YAML::Node modelYaml;
io::getYamlFromModel(modelYaml, "special:model.yml", items);
if(!modelYaml.IsNull()) {
LOG(info, "Loaded model config");
modelOptions->merge(modelYaml, true);
}
} catch(std::runtime_error&) {
LOG(warn, "No model settings found in model file");
else {
LOG(warn, "No model settings found in model file");
}
}
// l2r and r2l cannot be used in the same ensemble
@ -85,13 +85,24 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
}
}
scorers.push_back(scorerByType(fname, weights[i], model, modelOptions));
scorers.push_back(scorerByType(fname, weights[i], items, modelOptions));
i++;
}
return scorers;
}
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
std::vector<std::vector<io::Item>> model_items;
auto models = options->get<std::vector<std::string>>("models");
for(auto model : models) {
auto items = io::loadItems(model);
model_items.push_back(std::move(items));
}
return createScorers(options, model_items);
}
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<const void*>& ptrs) {
std::vector<Ptr<Scorer>> scorers;
@ -105,14 +116,16 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<c
// load options specific for the scorer
auto modelOptions = New<Options>(options->clone());
try {
if(!options->get<bool>("ignore-model-config")) {
YAML::Node modelYaml;
io::getYamlFromModel(modelYaml, "special:model.yml", ptr);
if(!options->get<bool>("ignore-model-config")) {
YAML::Node modelYaml;
io::getYamlFromModel(modelYaml, "special:model.yml", ptr);
if(!modelYaml.IsNull()) {
LOG(info, "Loaded model config");
modelOptions->merge(modelYaml, true);
}
} catch(std::runtime_error&) {
LOG(warn, "No model settings found in model file");
else {
LOG(warn, "No model settings found in model file");
}
}
scorers.push_back(scorerByType(fname, weights[i], ptr, modelOptions));

View File

@ -73,9 +73,19 @@ class ScorerWrapper : public Scorer {
private:
Ptr<IEncoderDecoder> encdec_;
std::string fname_;
std::vector<io::Item> items_;
const void* ptr_;
public:
ScorerWrapper(Ptr<models::IModel> encdec,
const std::string& name,
float weight,
std::vector<io::Item>& items)
: Scorer(name, weight),
encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
items_(items),
ptr_{0} {}
ScorerWrapper(Ptr<models::IModel> encdec,
const std::string& name,
float weight,
@ -97,7 +107,9 @@ public:
virtual void init(Ptr<ExpressionGraph> graph) override {
graph->switchParams(getName());
if(ptr_)
if(!items_.empty())
encdec_->load(graph, items_);
else if(ptr_)
encdec_->mmap(graph, ptr_);
else
encdec_->load(graph, fname_);
@ -142,12 +154,19 @@ public:
}
};
Ptr<Scorer> scorerByType(const std::string& fname,
float weight,
std::vector<io::Item> items,
Ptr<Options> options);
Ptr<Scorer> scorerByType(const std::string& fname,
float weight,
const std::string& model,
Ptr<Options> config);
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options);
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models);
Ptr<Scorer> scorerByType(const std::string& fname,
float weight,

View File

@ -20,12 +20,7 @@
#include "translator/scorers.h"
// currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled.
// @TODO: add this as an actual feature.
#define MMAP 0
#if MMAP
#include "3rd_party/mio/mio.hpp"
#endif
namespace marian {
@ -42,9 +37,8 @@ private:
size_t numDevices_;
#if MMAP
std::vector<mio::mmap_source> mmaps_;
#endif
std::vector<mio::mmap_source> model_mmaps_; // map
std::vector<std::vector<io::Item>> model_items_; // non-mmap
public:
Translate(Ptr<Options> options)
@ -76,15 +70,21 @@ public:
scorers_.resize(numDevices_);
graphs_.resize(numDevices_);
#if MMAP
auto models = options->get<std::vector<std::string>>("models");
for(auto model : models) {
marian::filesystem::Path modelPath(model);
ABORT_IF(modelPath.extension() != marian::filesystem::Path(".bin"),
"Non-binarized models cannot be mmapped");
mmaps_.push_back(std::move(mio::mmap_source(model)));
if(options_->get<bool>("model-mmap", false)) {
for(auto model : models) {
ABORT_IF(!io::isBin(model), "Non-binarized models cannot be mmapped");
LOG(info, "Loading model from {}", model);
model_mmaps_.push_back(mio::mmap_source(model));
}
}
else {
for(auto model : models) {
LOG(info, "Loading model from {}", model);
auto items = io::loadItems(model);
model_items_.push_back(std::move(items));
}
}
#endif
size_t id = 0;
for(auto device : devices) {
@ -101,11 +101,14 @@ public:
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_[id] = graph;
#if MMAP
auto scorers = createScorers(options_, mmaps_);
#else
auto scorers = createScorers(options_);
#endif
std::vector<Ptr<Scorer>> scorers;
if(options_->get<bool>("model-mmap", false)) {
scorers = createScorers(options_, model_mmaps_);
}
else {
scorers = createScorers(options_, model_items_);
}
for(auto scorer : scorers) {
scorer->init(graph);
if(shortlistGenerator_)
@ -146,11 +149,11 @@ public:
std::mutex syncCounts;
// timer and counters for total elapsed time and statistics
std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
size_t totBatches = 0;
size_t totLines = 0;
size_t totSourceTokens = 0;
// timer and counters for elapsed time and statistics between updates
std::unique_ptr<timer::Timer> curTimer(new timer::Timer());
size_t curBatches = 0;
@ -176,7 +179,7 @@ public:
bg.prepare();
for(auto batch : bg) {
auto task = [=, &syncCounts,
&totBatches, &totLines, &totSourceTokens, &totTimer,
&totBatches, &totLines, &totSourceTokens, &totTimer,
&curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local std::vector<Ptr<Scorer>> scorers;
@ -200,12 +203,12 @@ public:
}
// if we asked for speed information display this
if(statFreq.n > 0) {
if(statFreq.n > 0) {
std::lock_guard<std::mutex> lock(syncCounts);
totBatches++;
totBatches++;
totLines += batch->size();
totSourceTokens += batch->front()->batchWords();
curBatches++;
curLines += batch->size();
curSourceTokens += batch->front()->batchWords();
@ -214,10 +217,10 @@ public:
double totTime = totTimer->elapsed();
double curTime = curTimer->elapsed();
LOG(info,
"Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
LOG(info,
"Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime);
// reset stats between updates
curBatches = curLines = curSourceTokens = 0;
curTimer.reset(new timer::Timer());
@ -230,12 +233,12 @@ public:
// make sure threads are joined before other local variables get de-allocated
threadPool.join_all();
// display final speed numbers over total translation if intermediate displays were requested
if(statFreq.n > 0) {
double totTime = totTimer->elapsed();
LOG(info,
"Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
LOG(info,
"Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime);
}
}
@ -288,6 +291,14 @@ public:
auto devices = Config::getDevices(options_);
numDevices_ = devices.size();
// preload models
std::vector<std::vector<io::Item>> model_items_;
auto models = options->get<std::vector<std::string>>("models");
for(auto model : models) {
auto items = io::loadItems(model);
model_items_.push_back(std::move(items));
}
// initialize scorers
for(auto device : devices) {
auto graph = New<ExpressionGraph>(true);
@ -303,7 +314,7 @@ public:
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);
auto scorers = createScorers(options_);
auto scorers = createScorers(options_, model_items_);
for(auto scorer : scorers) {
scorer->init(graph);
if(shortlistGenerator_)