diff --git a/CHANGELOG.md b/CHANGELOG.md index a5dd305f..d42c652e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Dynamic gradient-scaling with `--dynamic-gradient-scaling`. - Add unit tests for binary files. - Fix compilation with OMP +- Added `--model-mmap` option to enable mmap loading for CPU-based translation - Compute aligned memory sizes using exact sizing - Support for loading lexical shortlist from a binary blob - Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 8da9520c..9705d5b7 100644 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -183,7 +183,12 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) { "Path prefix for pre-trained model to initialize model weights"); } } - +#ifdef COMPILE_CPU + if(mode_ == cli::mode::translation) { + cli.add("--model-mmap", + "Use memory-mapping when loading model (CPU only)"); + } +#endif cli.add("--ignore-model-config", "Ignore the model configuration saved in npz file"); cli.add("--type", diff --git a/src/common/config_validator.cpp b/src/common/config_validator.cpp index fea7578f..b0230da9 100644 --- a/src/common/config_validator.cpp +++ b/src/common/config_validator.cpp @@ -54,6 +54,9 @@ void ConfigValidator::validateOptionsTranslation() const { ABORT_IF(models.empty() && configs.empty(), "You need to provide at least one model file or a config file"); + ABORT_IF(get("model-mmap") && get("cpu-threads") == 0, + "Model MMAP is CPU-only, please use --cpu-threads"); + for(const auto& modelFile : models) { filesystem::Path modelPath(modelFile); ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile); diff --git a/src/common/io.cpp b/src/common/io.cpp index a9984b5d..e0b3f39a 100644 --- a/src/common/io.cpp +++ b/src/common/io.cpp @@ -56,6 +56,18 @@ void getYamlFromModel(YAML::Node& yaml, yaml = YAML::Load(item.data()); } +// Load YAML from item +void getYamlFromModel(YAML::Node& yaml, + const std::string& varName, + const std::vector& items) { + for(auto& item : items) { + if(item.name == varName) { + yaml = YAML::Load(item.data()); + return; + } + } +} + void addMetaToItems(const std::string& meta, const std::string& varName, std::vector& items) { diff --git a/src/common/io.h b/src/common/io.h index 2d18d66e..3f340ed2 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -21,6 +21,7 @@ bool isBin(const std::string& fileName); void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::string& fileName); void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const void* ptr); +void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::vector& items); void addMetaToItems(const std::string& meta, const std::string& varName, diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h index 553a5d63..c532abff 100644 --- a/src/graph/expression_graph.h +++ b/src/graph/expression_graph.h @@ -739,7 +739,7 @@ public: public: /** Load model (mainly parameter objects) from array of io::Items */ - void load(std::vector& ioItems, bool markReloaded = true) { + void load(const std::vector& ioItems, bool markReloaded = true) { setReloaded(false); for(auto& item : ioItems) { std::string pName = item.name; diff --git a/src/models/amun.h b/src/models/amun.h index 1bfda269..135ce359 100644 --- a/src/models/amun.h +++ b/src/models/amun.h @@ -36,7 +36,7 @@ public: } void load(Ptr graph, - const std::string& name, + const std::vector& items, bool /*markedReloaded*/ = true) override { std::map nameMap = {{"decoder_U", "decoder_cell1_U"}, @@ -89,9 +89,7 @@ public: if(opt("tied-embeddings-src") || opt("tied-embeddings-all")) nameMap["Wemb"] = "Wemb"; - LOG(info, "Loading model from {}", name); - // load items from .npz file - auto ioItems = io::loadItems(name); + auto ioItems = items; // map names and remove a dummy matrices for(auto it = ioItems.begin(); it != ioItems.end();) { // for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size @@ -120,6 +118,14 @@ public: graph->load(ioItems); } + void load(Ptr graph, + const std::string& name, + bool /*markReloaded*/ = true) override { + LOG(info, "Loading model from {}", name); + auto ioItems = io::loadItems(name); + load(graph, ioItems); + } + void save(Ptr graph, const std::string& name, bool saveTranslatorConfig = false) override { diff --git a/src/models/costs.h b/src/models/costs.h index e5463bfd..982a13c5 100644 --- a/src/models/costs.h +++ b/src/models/costs.h @@ -325,6 +325,12 @@ protected: public: Stepwise(Ptr encdec, Ptr cost) : encdec_(encdec), cost_(cost) {} + virtual void load(Ptr graph, + const std::vector& items, + bool markedReloaded = true) override { + encdec_->load(graph, items, markedReloaded); + } + virtual void load(Ptr graph, const std::string& name, bool markedReloaded = true) override { diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp index 66ff16ce..bb938ee5 100644 --- a/src/models/encoder_decoder.cpp +++ b/src/models/encoder_decoder.cpp @@ -144,6 +144,12 @@ std::string EncoderDecoder::getModelParametersAsString() { return std::string(out.c_str()); } +void EncoderDecoder::load(Ptr graph, + const std::vector& items, + bool markedReloaded) { + graph->load(items, markedReloaded && !opt("ignore-model-config", false)); +} + void EncoderDecoder::load(Ptr graph, const std::string& name, bool markedReloaded) { diff --git a/src/models/encoder_decoder.h b/src/models/encoder_decoder.h index 92c1647f..0fbf3faf 100644 --- a/src/models/encoder_decoder.h +++ b/src/models/encoder_decoder.h @@ -12,6 +12,11 @@ namespace marian { class IEncoderDecoder : public models::IModel { public: virtual ~IEncoderDecoder() {} + + virtual void load(Ptr graph, + const std::vector& items, + bool markedReloaded = true) = 0; + virtual void load(Ptr graph, const std::string& name, bool markedReloaded = true) override @@ -91,6 +96,10 @@ public: void push_back(Ptr decoder); + virtual void load(Ptr graph, + const std::vector& items, + bool markedReloaded = true) override; + virtual void load(Ptr graph, const std::string& name, bool markedReloaded = true) override; diff --git a/src/models/nematus.h b/src/models/nematus.h index 730418e5..aee8e3b0 100644 --- a/src/models/nematus.h +++ b/src/models/nematus.h @@ -26,11 +26,9 @@ public: } void load(Ptr graph, - const std::string& name, + const std::vector& items, bool /*markReloaded*/ = true) override { - LOG(info, "Loading model from {}", name); - // load items from .npz file - auto ioItems = io::loadItems(name); + auto ioItems = items; // map names and remove a dummy matrix 'decoder_c_tt' from items to avoid creating isolated node for(auto it = ioItems.begin(); it != ioItems.end();) { // for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size @@ -41,7 +39,7 @@ public: it->shape.set(0, 1); it->shape.set(1, dim); } - + if(it->name == "decoder_c_tt") { it = ioItems.erase(it); } else if(it->name == "uidx") { @@ -59,6 +57,14 @@ public: graph->load(ioItems); } + void load(Ptr graph, + const std::string& name, + bool /*markReloaded*/ = true) override { + LOG(info, "Loading model from {}", name); + auto ioItems = io::loadItems(name); + load(graph, ioItems); + } + void save(Ptr graph, const std::string& name, bool saveTranslatorConfig = false) override { diff --git a/src/translator/scorers.cpp b/src/translator/scorers.cpp index d1c8b160..60ec03dd 100644 --- a/src/translator/scorers.cpp +++ b/src/translator/scorers.cpp @@ -5,7 +5,7 @@ namespace marian { Ptr scorerByType(const std::string& fname, float weight, - const std::string& model, + std::vector items, Ptr options) { options->set("inference", true); std::string type = options->get("type"); @@ -22,7 +22,7 @@ Ptr scorerByType(const std::string& fname, LOG(info, "Loading scorer of type {} as feature {}", type, fname); - return New(encdec, fname, weight, model); + return New(encdec, fname, weight, items); } Ptr scorerByType(const std::string& fname, @@ -47,30 +47,30 @@ Ptr scorerByType(const std::string& fname, return New(encdec, fname, weight, ptr); } -std::vector> createScorers(Ptr options) { +std::vector> createScorers(Ptr options, const std::vector> models) { std::vector> scorers; - auto models = options->get>("models"); - std::vector weights(models.size(), 1.f); if(options->hasAndNotEmpty("weights")) weights = options->get>("weights"); bool isPrevRightLeft = false; // if the previous model was a right-to-left model size_t i = 0; - for(auto model : models) { + for(auto items : models) { std::string fname = "F" + std::to_string(i); // load options specific for the scorer auto modelOptions = New(options->clone()); - try { - if(!options->get("ignore-model-config")) { - YAML::Node modelYaml; - io::getYamlFromModel(modelYaml, "special:model.yml", model); + if(!options->get("ignore-model-config")) { + YAML::Node modelYaml; + io::getYamlFromModel(modelYaml, "special:model.yml", items); + if(!modelYaml.IsNull()) { + LOG(info, "Loaded model config"); modelOptions->merge(modelYaml, true); } - } catch(std::runtime_error&) { - LOG(warn, "No model settings found in model file"); + else { + LOG(warn, "No model settings found in model file"); + } } // l2r and r2l cannot be used in the same ensemble @@ -85,13 +85,24 @@ std::vector> createScorers(Ptr options) { } } - scorers.push_back(scorerByType(fname, weights[i], model, modelOptions)); + scorers.push_back(scorerByType(fname, weights[i], items, modelOptions)); i++; } return scorers; } +std::vector> createScorers(Ptr options) { + std::vector> model_items; + auto models = options->get>("models"); + for(auto model : models) { + auto items = io::loadItems(model); + model_items.push_back(std::move(items)); + } + + return createScorers(options, model_items); +} + std::vector> createScorers(Ptr options, const std::vector& ptrs) { std::vector> scorers; @@ -105,14 +116,16 @@ std::vector> createScorers(Ptr options, const std::vector(options->clone()); - try { - if(!options->get("ignore-model-config")) { - YAML::Node modelYaml; - io::getYamlFromModel(modelYaml, "special:model.yml", ptr); + if(!options->get("ignore-model-config")) { + YAML::Node modelYaml; + io::getYamlFromModel(modelYaml, "special:model.yml", ptr); + if(!modelYaml.IsNull()) { + LOG(info, "Loaded model config"); modelOptions->merge(modelYaml, true); } - } catch(std::runtime_error&) { - LOG(warn, "No model settings found in model file"); + else { + LOG(warn, "No model settings found in model file"); + } } scorers.push_back(scorerByType(fname, weights[i], ptr, modelOptions)); diff --git a/src/translator/scorers.h b/src/translator/scorers.h index a5a0be2c..72ebff5d 100644 --- a/src/translator/scorers.h +++ b/src/translator/scorers.h @@ -73,9 +73,19 @@ class ScorerWrapper : public Scorer { private: Ptr encdec_; std::string fname_; + std::vector items_; const void* ptr_; public: + ScorerWrapper(Ptr encdec, + const std::string& name, + float weight, + std::vector& items) + : Scorer(name, weight), + encdec_(std::static_pointer_cast(encdec)), + items_(items), + ptr_{0} {} + ScorerWrapper(Ptr encdec, const std::string& name, float weight, @@ -97,7 +107,9 @@ public: virtual void init(Ptr graph) override { graph->switchParams(getName()); - if(ptr_) + if(!items_.empty()) + encdec_->load(graph, items_); + else if(ptr_) encdec_->mmap(graph, ptr_); else encdec_->load(graph, fname_); @@ -142,12 +154,19 @@ public: } }; +Ptr scorerByType(const std::string& fname, + float weight, + std::vector items, + Ptr options); + Ptr scorerByType(const std::string& fname, float weight, const std::string& model, Ptr config); + std::vector> createScorers(Ptr options); +std::vector> createScorers(Ptr options, const std::vector> models); Ptr scorerByType(const std::string& fname, float weight, diff --git a/src/translator/translator.h b/src/translator/translator.h index db1f3d03..4084ced9 100644 --- a/src/translator/translator.h +++ b/src/translator/translator.h @@ -20,12 +20,7 @@ #include "translator/scorers.h" // currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled. -// @TODO: add this as an actual feature. -#define MMAP 0 - -#if MMAP #include "3rd_party/mio/mio.hpp" -#endif namespace marian { @@ -42,9 +37,8 @@ private: size_t numDevices_; -#if MMAP - std::vector mmaps_; -#endif + std::vector model_mmaps_; // map + std::vector> model_items_; // non-mmap public: Translate(Ptr options) @@ -76,15 +70,21 @@ public: scorers_.resize(numDevices_); graphs_.resize(numDevices_); -#if MMAP auto models = options->get>("models"); - for(auto model : models) { - marian::filesystem::Path modelPath(model); - ABORT_IF(modelPath.extension() != marian::filesystem::Path(".bin"), - "Non-binarized models cannot be mmapped"); - mmaps_.push_back(std::move(mio::mmap_source(model))); + if(options_->get("model-mmap", false)) { + for(auto model : models) { + ABORT_IF(!io::isBin(model), "Non-binarized models cannot be mmapped"); + LOG(info, "Loading model from {}", model); + model_mmaps_.push_back(mio::mmap_source(model)); + } + } + else { + for(auto model : models) { + LOG(info, "Loading model from {}", model); + auto items = io::loadItems(model); + model_items_.push_back(std::move(items)); + } } -#endif size_t id = 0; for(auto device : devices) { @@ -101,11 +101,14 @@ public: graph->reserveWorkspaceMB(options_->get("workspace")); graphs_[id] = graph; -#if MMAP - auto scorers = createScorers(options_, mmaps_); -#else - auto scorers = createScorers(options_); -#endif + std::vector> scorers; + if(options_->get("model-mmap", false)) { + scorers = createScorers(options_, model_mmaps_); + } + else { + scorers = createScorers(options_, model_items_); + } + for(auto scorer : scorers) { scorer->init(graph); if(shortlistGenerator_) @@ -146,11 +149,11 @@ public: std::mutex syncCounts; // timer and counters for total elapsed time and statistics - std::unique_ptr totTimer(new timer::Timer()); + std::unique_ptr totTimer(new timer::Timer()); size_t totBatches = 0; size_t totLines = 0; size_t totSourceTokens = 0; - + // timer and counters for elapsed time and statistics between updates std::unique_ptr curTimer(new timer::Timer()); size_t curBatches = 0; @@ -176,7 +179,7 @@ public: bg.prepare(); for(auto batch : bg) { auto task = [=, &syncCounts, - &totBatches, &totLines, &totSourceTokens, &totTimer, + &totBatches, &totLines, &totSourceTokens, &totTimer, &curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) { thread_local Ptr graph; thread_local std::vector> scorers; @@ -200,12 +203,12 @@ public: } // if we asked for speed information display this - if(statFreq.n > 0) { + if(statFreq.n > 0) { std::lock_guard lock(syncCounts); - totBatches++; + totBatches++; totLines += batch->size(); totSourceTokens += batch->front()->batchWords(); - + curBatches++; curLines += batch->size(); curSourceTokens += batch->front()->batchWords(); @@ -214,10 +217,10 @@ public: double totTime = totTimer->elapsed(); double curTime = curTimer->elapsed(); - LOG(info, - "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", + LOG(info, + "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime); - + // reset stats between updates curBatches = curLines = curSourceTokens = 0; curTimer.reset(new timer::Timer()); @@ -230,12 +233,12 @@ public: // make sure threads are joined before other local variables get de-allocated threadPool.join_all(); - + // display final speed numbers over total translation if intermediate displays were requested if(statFreq.n > 0) { double totTime = totTimer->elapsed(); - LOG(info, - "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", + LOG(info, + "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime); } } @@ -288,6 +291,14 @@ public: auto devices = Config::getDevices(options_); numDevices_ = devices.size(); + // preload models + std::vector> model_items_; + auto models = options->get>("models"); + for(auto model : models) { + auto items = io::loadItems(model); + model_items_.push_back(std::move(items)); + } + // initialize scorers for(auto device : devices) { auto graph = New(true); @@ -303,7 +314,7 @@ public: graph->reserveWorkspaceMB(options_->get("workspace")); graphs_.push_back(graph); - auto scorers = createScorers(options_); + auto scorers = createScorers(options_, model_items_); for(auto scorer : scorers) { scorer->init(graph); if(shortlistGenerator_)