Scorer model loading (#860)

* Add MMAP as an option
* Use io::isBin
* Allow getYamlFromModel from an Item vector
* ScorerWrapper can now load onto a graph from an Item vector
The interface IEncoderDecoder can now call graph loads directly from an
Item vector.
* Translator loads model before creating scorers
Scorers are created from an Item vector
* Replace model-config try-catch with check using IsNull
* Prefer empty() over size() == 0 checks
* load by items should be pure virtual
* Stepwise forward load to encdec
* nematus can load from items
* amun can load from items
* loadItems in TranslateService
* Remove logging
* Remove by filename scorer functions
* Replace by filename createScorer
* Explicitly provide default value for get model-mmap
* CLI option for model-mmap only for translation and CPU compile
* Ensure model-mmap option is CPU only
* Remove move on temporary object
* Reinstate log messages for model loading in Amun / Nematus
* Add log messages for model loading in scorers

Co-authored-by: Roman Grundkiewicz <rgrundkiewicz@gmail.com>
This commit is contained in:
Graeme Nail 2022-01-18 12:58:52 +00:00 committed by GitHub
parent c84599d08a
commit b29cc07a95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 162 additions and 64 deletions

View File

@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Dynamic gradient-scaling with `--dynamic-gradient-scaling`. - Dynamic gradient-scaling with `--dynamic-gradient-scaling`.
- Add unit tests for binary files. - Add unit tests for binary files.
- Fix compilation with OMP - Fix compilation with OMP
- Added `--model-mmap` option to enable mmap loading for CPU-based translation
- Compute aligned memory sizes using exact sizing - Compute aligned memory sizes using exact sizing
- Support for loading lexical shortlist from a binary blob - Support for loading lexical shortlist from a binary blob
- Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option - Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option

View File

@ -183,7 +183,12 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"Path prefix for pre-trained model to initialize model weights"); "Path prefix for pre-trained model to initialize model weights");
} }
} }
#ifdef COMPILE_CPU
if(mode_ == cli::mode::translation) {
cli.add<bool>("--model-mmap",
"Use memory-mapping when loading model (CPU only)");
}
#endif
cli.add<bool>("--ignore-model-config", cli.add<bool>("--ignore-model-config",
"Ignore the model configuration saved in npz file"); "Ignore the model configuration saved in npz file");
cli.add<std::string>("--type", cli.add<std::string>("--type",

View File

@ -54,6 +54,9 @@ void ConfigValidator::validateOptionsTranslation() const {
ABORT_IF(models.empty() && configs.empty(), ABORT_IF(models.empty() && configs.empty(),
"You need to provide at least one model file or a config file"); "You need to provide at least one model file or a config file");
ABORT_IF(get<bool>("model-mmap") && get<size_t>("cpu-threads") == 0,
"Model MMAP is CPU-only, please use --cpu-threads");
for(const auto& modelFile : models) { for(const auto& modelFile : models) {
filesystem::Path modelPath(modelFile); filesystem::Path modelPath(modelFile);
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile); ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile);

View File

@ -56,6 +56,18 @@ void getYamlFromModel(YAML::Node& yaml,
yaml = YAML::Load(item.data()); yaml = YAML::Load(item.data());
} }
// Parse the payload of the io::Item named `varName` into `yaml`.
// If no matching item exists, `yaml` is left untouched (callers can
// detect that case via YAML::Node::IsNull()).
void getYamlFromModel(YAML::Node& yaml,
                      const std::string& varName,
                      const std::vector<Item>& items) {
  for(const auto& item : items) {
    if(item.name != varName)
      continue;
    yaml = YAML::Load(item.data());
    break;
  }
}
void addMetaToItems(const std::string& meta, void addMetaToItems(const std::string& meta,
const std::string& varName, const std::string& varName,
std::vector<io::Item>& items) { std::vector<io::Item>& items) {

View File

@ -21,6 +21,7 @@ bool isBin(const std::string& fileName);
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::string& fileName); void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::string& fileName);
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const void* ptr); void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const void* ptr);
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::vector<Item>& items);
void addMetaToItems(const std::string& meta, void addMetaToItems(const std::string& meta,
const std::string& varName, const std::string& varName,

View File

@ -739,7 +739,7 @@ public:
public: public:
/** Load model (mainly parameter objects) from array of io::Items */ /** Load model (mainly parameter objects) from array of io::Items */
void load(std::vector<io::Item>& ioItems, bool markReloaded = true) { void load(const std::vector<io::Item>& ioItems, bool markReloaded = true) {
setReloaded(false); setReloaded(false);
for(auto& item : ioItems) { for(auto& item : ioItems) {
std::string pName = item.name; std::string pName = item.name;

View File

@ -36,7 +36,7 @@ public:
} }
void load(Ptr<ExpressionGraph> graph, void load(Ptr<ExpressionGraph> graph,
const std::string& name, const std::vector<io::Item>& items,
bool /*markedReloaded*/ = true) override { bool /*markedReloaded*/ = true) override {
std::map<std::string, std::string> nameMap std::map<std::string, std::string> nameMap
= {{"decoder_U", "decoder_cell1_U"}, = {{"decoder_U", "decoder_cell1_U"},
@ -89,9 +89,7 @@ public:
if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
nameMap["Wemb"] = "Wemb"; nameMap["Wemb"] = "Wemb";
LOG(info, "Loading model from {}", name); auto ioItems = items;
// load items from .npz file
auto ioItems = io::loadItems(name);
// map names and remove a dummy matrices // map names and remove a dummy matrices
for(auto it = ioItems.begin(); it != ioItems.end();) { for(auto it = ioItems.begin(); it != ioItems.end();) {
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size // for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
@ -120,6 +118,14 @@ public:
graph->load(ioItems); graph->load(ioItems);
} }
// File-based load: reads io::Items from `name` and delegates to the
// io::Item overload, which performs the legacy name mapping.
void load(Ptr<ExpressionGraph> graph,
          const std::string& name,
          bool /*markReloaded*/ = true) override {
  LOG(info, "Loading model from {}", name);
  load(graph, io::loadItems(name));
}
void save(Ptr<ExpressionGraph> graph, void save(Ptr<ExpressionGraph> graph,
const std::string& name, const std::string& name,
bool saveTranslatorConfig = false) override { bool saveTranslatorConfig = false) override {

View File

@ -325,6 +325,12 @@ protected:
public: public:
Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {} Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {}
// Item-based load: forwards directly to the wrapped encoder-decoder.
virtual void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded = true) override {
encdec_->load(graph, items, markedReloaded);
}
virtual void load(Ptr<ExpressionGraph> graph, virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name, const std::string& name,
bool markedReloaded = true) override { bool markedReloaded = true) override {

View File

@ -144,6 +144,12 @@ std::string EncoderDecoder::getModelParametersAsString() {
return std::string(out.c_str()); return std::string(out.c_str());
} }
// Load model parameters into the graph from already-parsed io::Items
// (no file access). The reload flag passed to the graph is suppressed
// when --ignore-model-config is set.
void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded) {
graph->load(items, markedReloaded && !opt<bool>("ignore-model-config", false));
}
void EncoderDecoder::load(Ptr<ExpressionGraph> graph, void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
const std::string& name, const std::string& name,
bool markedReloaded) { bool markedReloaded) {

View File

@ -12,6 +12,11 @@ namespace marian {
class IEncoderDecoder : public models::IModel { class IEncoderDecoder : public models::IModel {
public: public:
virtual ~IEncoderDecoder() {} virtual ~IEncoderDecoder() {}
virtual void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded = true) = 0;
virtual void load(Ptr<ExpressionGraph> graph, virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name, const std::string& name,
bool markedReloaded = true) override bool markedReloaded = true) override
@ -91,6 +96,10 @@ public:
void push_back(Ptr<DecoderBase> decoder); void push_back(Ptr<DecoderBase> decoder);
virtual void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded = true) override;
virtual void load(Ptr<ExpressionGraph> graph, virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name, const std::string& name,
bool markedReloaded = true) override; bool markedReloaded = true) override;

View File

@ -26,11 +26,9 @@ public:
} }
void load(Ptr<ExpressionGraph> graph, void load(Ptr<ExpressionGraph> graph,
const std::string& name, const std::vector<io::Item>& items,
bool /*markReloaded*/ = true) override { bool /*markReloaded*/ = true) override {
LOG(info, "Loading model from {}", name); auto ioItems = items;
// load items from .npz file
auto ioItems = io::loadItems(name);
// map names and remove a dummy matrix 'decoder_c_tt' from items to avoid creating isolated node // map names and remove a dummy matrix 'decoder_c_tt' from items to avoid creating isolated node
for(auto it = ioItems.begin(); it != ioItems.end();) { for(auto it = ioItems.begin(); it != ioItems.end();) {
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size // for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
@ -41,7 +39,7 @@ public:
it->shape.set(0, 1); it->shape.set(0, 1);
it->shape.set(1, dim); it->shape.set(1, dim);
} }
if(it->name == "decoder_c_tt") { if(it->name == "decoder_c_tt") {
it = ioItems.erase(it); it = ioItems.erase(it);
} else if(it->name == "uidx") { } else if(it->name == "uidx") {
@ -59,6 +57,14 @@ public:
graph->load(ioItems); graph->load(ioItems);
} }
// File-based load: reads io::Items from disk, then forwards to the
// io::Item overload above for the actual graph load.
void load(Ptr<ExpressionGraph> graph,
          const std::string& name,
          bool /*markReloaded*/ = true) override {
  LOG(info, "Loading model from {}", name);
  auto diskItems = io::loadItems(name);
  load(graph, diskItems);
}
void save(Ptr<ExpressionGraph> graph, void save(Ptr<ExpressionGraph> graph,
const std::string& name, const std::string& name,
bool saveTranslatorConfig = false) override { bool saveTranslatorConfig = false) override {

View File

@ -5,7 +5,7 @@ namespace marian {
Ptr<Scorer> scorerByType(const std::string& fname, Ptr<Scorer> scorerByType(const std::string& fname,
float weight, float weight,
const std::string& model, std::vector<io::Item> items,
Ptr<Options> options) { Ptr<Options> options) {
options->set("inference", true); options->set("inference", true);
std::string type = options->get<std::string>("type"); std::string type = options->get<std::string>("type");
@ -22,7 +22,7 @@ Ptr<Scorer> scorerByType(const std::string& fname,
LOG(info, "Loading scorer of type {} as feature {}", type, fname); LOG(info, "Loading scorer of type {} as feature {}", type, fname);
return New<ScorerWrapper>(encdec, fname, weight, model); return New<ScorerWrapper>(encdec, fname, weight, items);
} }
Ptr<Scorer> scorerByType(const std::string& fname, Ptr<Scorer> scorerByType(const std::string& fname,
@ -47,30 +47,30 @@ Ptr<Scorer> scorerByType(const std::string& fname,
return New<ScorerWrapper>(encdec, fname, weight, ptr); return New<ScorerWrapper>(encdec, fname, weight, ptr);
} }
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) { std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models) {
std::vector<Ptr<Scorer>> scorers; std::vector<Ptr<Scorer>> scorers;
auto models = options->get<std::vector<std::string>>("models");
std::vector<float> weights(models.size(), 1.f); std::vector<float> weights(models.size(), 1.f);
if(options->hasAndNotEmpty("weights")) if(options->hasAndNotEmpty("weights"))
weights = options->get<std::vector<float>>("weights"); weights = options->get<std::vector<float>>("weights");
bool isPrevRightLeft = false; // if the previous model was a right-to-left model bool isPrevRightLeft = false; // if the previous model was a right-to-left model
size_t i = 0; size_t i = 0;
for(auto model : models) { for(auto items : models) {
std::string fname = "F" + std::to_string(i); std::string fname = "F" + std::to_string(i);
// load options specific for the scorer // load options specific for the scorer
auto modelOptions = New<Options>(options->clone()); auto modelOptions = New<Options>(options->clone());
try { if(!options->get<bool>("ignore-model-config")) {
if(!options->get<bool>("ignore-model-config")) { YAML::Node modelYaml;
YAML::Node modelYaml; io::getYamlFromModel(modelYaml, "special:model.yml", items);
io::getYamlFromModel(modelYaml, "special:model.yml", model); if(!modelYaml.IsNull()) {
LOG(info, "Loaded model config");
modelOptions->merge(modelYaml, true); modelOptions->merge(modelYaml, true);
} }
} catch(std::runtime_error&) { else {
LOG(warn, "No model settings found in model file"); LOG(warn, "No model settings found in model file");
}
} }
// l2r and r2l cannot be used in the same ensemble // l2r and r2l cannot be used in the same ensemble
@ -85,13 +85,24 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
} }
} }
scorers.push_back(scorerByType(fname, weights[i], model, modelOptions)); scorers.push_back(scorerByType(fname, weights[i], items, modelOptions));
i++; i++;
} }
return scorers; return scorers;
} }
// Create scorers by loading each model listed in --models from disk,
// then delegating to the io::Item-based overload.
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
  std::vector<std::vector<io::Item>> model_items;
  auto models = options->get<std::vector<std::string>>("models");
  model_items.reserve(models.size());        // one allocation, no regrowth
  for(const auto& model : models) {          // const&: avoid copying each path string
    model_items.push_back(io::loadItems(model));
  }
  return createScorers(options, model_items);
}
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<const void*>& ptrs) { std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<const void*>& ptrs) {
std::vector<Ptr<Scorer>> scorers; std::vector<Ptr<Scorer>> scorers;
@ -105,14 +116,16 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<c
// load options specific for the scorer // load options specific for the scorer
auto modelOptions = New<Options>(options->clone()); auto modelOptions = New<Options>(options->clone());
try { if(!options->get<bool>("ignore-model-config")) {
if(!options->get<bool>("ignore-model-config")) { YAML::Node modelYaml;
YAML::Node modelYaml; io::getYamlFromModel(modelYaml, "special:model.yml", ptr);
io::getYamlFromModel(modelYaml, "special:model.yml", ptr); if(!modelYaml.IsNull()) {
LOG(info, "Loaded model config");
modelOptions->merge(modelYaml, true); modelOptions->merge(modelYaml, true);
} }
} catch(std::runtime_error&) { else {
LOG(warn, "No model settings found in model file"); LOG(warn, "No model settings found in model file");
}
} }
scorers.push_back(scorerByType(fname, weights[i], ptr, modelOptions)); scorers.push_back(scorerByType(fname, weights[i], ptr, modelOptions));

View File

@ -73,9 +73,19 @@ class ScorerWrapper : public Scorer {
private: private:
Ptr<IEncoderDecoder> encdec_; Ptr<IEncoderDecoder> encdec_;
std::string fname_; std::string fname_;
std::vector<io::Item> items_;
const void* ptr_; const void* ptr_;
public: public:
// Construct a scorer backed by preloaded io::Items (copied into
// items_); ptr_ stays null so init() takes the item-based load path.
// `items` is only read, so take it by const reference: documents
// intent, accepts temporaries, and remains compatible with all
// existing (lvalue) callers.
ScorerWrapper(Ptr<models::IModel> encdec,
              const std::string& name,
              float weight,
              const std::vector<io::Item>& items)
    : Scorer(name, weight),
      encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
      items_(items),
      ptr_{0} {}
ScorerWrapper(Ptr<models::IModel> encdec, ScorerWrapper(Ptr<models::IModel> encdec,
const std::string& name, const std::string& name,
float weight, float weight,
@ -97,7 +107,9 @@ public:
virtual void init(Ptr<ExpressionGraph> graph) override { virtual void init(Ptr<ExpressionGraph> graph) override {
graph->switchParams(getName()); graph->switchParams(getName());
if(ptr_) if(!items_.empty())
encdec_->load(graph, items_);
else if(ptr_)
encdec_->mmap(graph, ptr_); encdec_->mmap(graph, ptr_);
else else
encdec_->load(graph, fname_); encdec_->load(graph, fname_);
@ -142,12 +154,19 @@ public:
} }
}; };
Ptr<Scorer> scorerByType(const std::string& fname,
float weight,
std::vector<io::Item> items,
Ptr<Options> options);
Ptr<Scorer> scorerByType(const std::string& fname, Ptr<Scorer> scorerByType(const std::string& fname,
float weight, float weight,
const std::string& model, const std::string& model,
Ptr<Options> config); Ptr<Options> config);
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options); std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options);
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models);
Ptr<Scorer> scorerByType(const std::string& fname, Ptr<Scorer> scorerByType(const std::string& fname,
float weight, float weight,

View File

@ -20,12 +20,7 @@
#include "translator/scorers.h" #include "translator/scorers.h"
// currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled. // currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled.
// @TODO: add this as an actual feature.
#define MMAP 0
#if MMAP
#include "3rd_party/mio/mio.hpp" #include "3rd_party/mio/mio.hpp"
#endif
namespace marian { namespace marian {
@ -42,9 +37,8 @@ private:
size_t numDevices_; size_t numDevices_;
#if MMAP std::vector<mio::mmap_source> model_mmaps_; // map
std::vector<mio::mmap_source> mmaps_; std::vector<std::vector<io::Item>> model_items_; // non-mmap
#endif
public: public:
Translate(Ptr<Options> options) Translate(Ptr<Options> options)
@ -76,15 +70,21 @@ public:
scorers_.resize(numDevices_); scorers_.resize(numDevices_);
graphs_.resize(numDevices_); graphs_.resize(numDevices_);
#if MMAP
auto models = options->get<std::vector<std::string>>("models"); auto models = options->get<std::vector<std::string>>("models");
for(auto model : models) { if(options_->get<bool>("model-mmap", false)) {
marian::filesystem::Path modelPath(model); for(auto model : models) {
ABORT_IF(modelPath.extension() != marian::filesystem::Path(".bin"), ABORT_IF(!io::isBin(model), "Non-binarized models cannot be mmapped");
"Non-binarized models cannot be mmapped"); LOG(info, "Loading model from {}", model);
mmaps_.push_back(std::move(mio::mmap_source(model))); model_mmaps_.push_back(mio::mmap_source(model));
}
}
else {
for(auto model : models) {
LOG(info, "Loading model from {}", model);
auto items = io::loadItems(model);
model_items_.push_back(std::move(items));
}
} }
#endif
size_t id = 0; size_t id = 0;
for(auto device : devices) { for(auto device : devices) {
@ -101,11 +101,14 @@ public:
graph->reserveWorkspaceMB(options_->get<size_t>("workspace")); graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_[id] = graph; graphs_[id] = graph;
#if MMAP std::vector<Ptr<Scorer>> scorers;
auto scorers = createScorers(options_, mmaps_); if(options_->get<bool>("model-mmap", false)) {
#else scorers = createScorers(options_, model_mmaps_);
auto scorers = createScorers(options_); }
#endif else {
scorers = createScorers(options_, model_items_);
}
for(auto scorer : scorers) { for(auto scorer : scorers) {
scorer->init(graph); scorer->init(graph);
if(shortlistGenerator_) if(shortlistGenerator_)
@ -146,11 +149,11 @@ public:
std::mutex syncCounts; std::mutex syncCounts;
// timer and counters for total elapsed time and statistics // timer and counters for total elapsed time and statistics
std::unique_ptr<timer::Timer> totTimer(new timer::Timer()); std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
size_t totBatches = 0; size_t totBatches = 0;
size_t totLines = 0; size_t totLines = 0;
size_t totSourceTokens = 0; size_t totSourceTokens = 0;
// timer and counters for elapsed time and statistics between updates // timer and counters for elapsed time and statistics between updates
std::unique_ptr<timer::Timer> curTimer(new timer::Timer()); std::unique_ptr<timer::Timer> curTimer(new timer::Timer());
size_t curBatches = 0; size_t curBatches = 0;
@ -176,7 +179,7 @@ public:
bg.prepare(); bg.prepare();
for(auto batch : bg) { for(auto batch : bg) {
auto task = [=, &syncCounts, auto task = [=, &syncCounts,
&totBatches, &totLines, &totSourceTokens, &totTimer, &totBatches, &totLines, &totSourceTokens, &totTimer,
&curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) { &curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) {
thread_local Ptr<ExpressionGraph> graph; thread_local Ptr<ExpressionGraph> graph;
thread_local std::vector<Ptr<Scorer>> scorers; thread_local std::vector<Ptr<Scorer>> scorers;
@ -200,12 +203,12 @@ public:
} }
// if we asked for speed information display this // if we asked for speed information display this
if(statFreq.n > 0) { if(statFreq.n > 0) {
std::lock_guard<std::mutex> lock(syncCounts); std::lock_guard<std::mutex> lock(syncCounts);
totBatches++; totBatches++;
totLines += batch->size(); totLines += batch->size();
totSourceTokens += batch->front()->batchWords(); totSourceTokens += batch->front()->batchWords();
curBatches++; curBatches++;
curLines += batch->size(); curLines += batch->size();
curSourceTokens += batch->front()->batchWords(); curSourceTokens += batch->front()->batchWords();
@ -214,10 +217,10 @@ public:
double totTime = totTimer->elapsed(); double totTime = totTimer->elapsed();
double curTime = curTimer->elapsed(); double curTime = curTimer->elapsed();
LOG(info, LOG(info,
"Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime); totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime);
// reset stats between updates // reset stats between updates
curBatches = curLines = curSourceTokens = 0; curBatches = curLines = curSourceTokens = 0;
curTimer.reset(new timer::Timer()); curTimer.reset(new timer::Timer());
@ -230,12 +233,12 @@ public:
// make sure threads are joined before other local variables get de-allocated // make sure threads are joined before other local variables get de-allocated
threadPool.join_all(); threadPool.join_all();
// display final speed numbers over total translation if intermediate displays were requested // display final speed numbers over total translation if intermediate displays were requested
if(statFreq.n > 0) { if(statFreq.n > 0) {
double totTime = totTimer->elapsed(); double totTime = totTimer->elapsed();
LOG(info, LOG(info,
"Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s", "Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime); totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime);
} }
} }
@ -288,6 +291,14 @@ public:
auto devices = Config::getDevices(options_); auto devices = Config::getDevices(options_);
numDevices_ = devices.size(); numDevices_ = devices.size();
// preload models
std::vector<std::vector<io::Item>> model_items_;
auto models = options->get<std::vector<std::string>>("models");
for(auto model : models) {
auto items = io::loadItems(model);
model_items_.push_back(std::move(items));
}
// initialize scorers // initialize scorers
for(auto device : devices) { for(auto device : devices) {
auto graph = New<ExpressionGraph>(true); auto graph = New<ExpressionGraph>(true);
@ -303,7 +314,7 @@ public:
graph->reserveWorkspaceMB(options_->get<size_t>("workspace")); graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph); graphs_.push_back(graph);
auto scorers = createScorers(options_); auto scorers = createScorers(options_, model_items_);
for(auto scorer : scorers) { for(auto scorer : scorers) {
scorer->init(graph); scorer->init(graph);
if(shortlistGenerator_) if(shortlistGenerator_)