mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Scorer model loading (#860)
* Add MMAP as an option * Use io::isBin * Allow getYamlFromModel from an Item vector * ScorerWrapper can now load on to a graph from Item vector The interface IEncoderDecoder can now call graph loads directly from an Item Vector. * Translator loads model before creating scorers Scorers are created from an Item vector * Replace model-config try-catch with check using IsNull * Prefer empty vs size * load by items should be pure virtual * Stepwise forward load to encdec * nematus can load from items * amun can load from items * loadItems in TranslateService * Remove logging * Remove by filename scorer functions * Replace by filename createScorer * Explicitly provide default value for get model-mmap * CLI option for model-mmap only for translation and CPU compile * Ensure model-mmap option is CPU only * Remove move on temporary object * Reinstate log messages for model loading in Amun / Nematus * Add log messages for model loading in scorers Co-authored-by: Roman Grundkiewicz <rgrundkiewicz@gmail.com>
This commit is contained in:
parent
c84599d08a
commit
b29cc07a95
@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||||||
- Dynamic gradient-scaling with `--dynamic-gradient-scaling`.
|
- Dynamic gradient-scaling with `--dynamic-gradient-scaling`.
|
||||||
- Add unit tests for binary files.
|
- Add unit tests for binary files.
|
||||||
- Fix compilation with OMP
|
- Fix compilation with OMP
|
||||||
|
- Added `--model-mmap` option to enable mmap loading for CPU-based translation
|
||||||
- Compute aligned memory sizes using exact sizing
|
- Compute aligned memory sizes using exact sizing
|
||||||
- Support for loading lexical shortlist from a binary blob
|
- Support for loading lexical shortlist from a binary blob
|
||||||
- Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option
|
- Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option
|
||||||
|
@ -183,7 +183,12 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
|||||||
"Path prefix for pre-trained model to initialize model weights");
|
"Path prefix for pre-trained model to initialize model weights");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#ifdef COMPILE_CPU
|
||||||
|
if(mode_ == cli::mode::translation) {
|
||||||
|
cli.add<bool>("--model-mmap",
|
||||||
|
"Use memory-mapping when loading model (CPU only)");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
cli.add<bool>("--ignore-model-config",
|
cli.add<bool>("--ignore-model-config",
|
||||||
"Ignore the model configuration saved in npz file");
|
"Ignore the model configuration saved in npz file");
|
||||||
cli.add<std::string>("--type",
|
cli.add<std::string>("--type",
|
||||||
|
@ -54,6 +54,9 @@ void ConfigValidator::validateOptionsTranslation() const {
|
|||||||
ABORT_IF(models.empty() && configs.empty(),
|
ABORT_IF(models.empty() && configs.empty(),
|
||||||
"You need to provide at least one model file or a config file");
|
"You need to provide at least one model file or a config file");
|
||||||
|
|
||||||
|
ABORT_IF(get<bool>("model-mmap") && get<size_t>("cpu-threads") == 0,
|
||||||
|
"Model MMAP is CPU-only, please use --cpu-threads");
|
||||||
|
|
||||||
for(const auto& modelFile : models) {
|
for(const auto& modelFile : models) {
|
||||||
filesystem::Path modelPath(modelFile);
|
filesystem::Path modelPath(modelFile);
|
||||||
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile);
|
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile);
|
||||||
|
@ -56,6 +56,18 @@ void getYamlFromModel(YAML::Node& yaml,
|
|||||||
yaml = YAML::Load(item.data());
|
yaml = YAML::Load(item.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load YAML from item
|
||||||
|
void getYamlFromModel(YAML::Node& yaml,
|
||||||
|
const std::string& varName,
|
||||||
|
const std::vector<Item>& items) {
|
||||||
|
for(auto& item : items) {
|
||||||
|
if(item.name == varName) {
|
||||||
|
yaml = YAML::Load(item.data());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void addMetaToItems(const std::string& meta,
|
void addMetaToItems(const std::string& meta,
|
||||||
const std::string& varName,
|
const std::string& varName,
|
||||||
std::vector<io::Item>& items) {
|
std::vector<io::Item>& items) {
|
||||||
|
@ -21,6 +21,7 @@ bool isBin(const std::string& fileName);
|
|||||||
|
|
||||||
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::string& fileName);
|
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::string& fileName);
|
||||||
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const void* ptr);
|
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const void* ptr);
|
||||||
|
void getYamlFromModel(YAML::Node& yaml, const std::string& varName, const std::vector<Item>& items);
|
||||||
|
|
||||||
void addMetaToItems(const std::string& meta,
|
void addMetaToItems(const std::string& meta,
|
||||||
const std::string& varName,
|
const std::string& varName,
|
||||||
|
@ -739,7 +739,7 @@ public:
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
/** Load model (mainly parameter objects) from array of io::Items */
|
/** Load model (mainly parameter objects) from array of io::Items */
|
||||||
void load(std::vector<io::Item>& ioItems, bool markReloaded = true) {
|
void load(const std::vector<io::Item>& ioItems, bool markReloaded = true) {
|
||||||
setReloaded(false);
|
setReloaded(false);
|
||||||
for(auto& item : ioItems) {
|
for(auto& item : ioItems) {
|
||||||
std::string pName = item.name;
|
std::string pName = item.name;
|
||||||
|
@ -36,7 +36,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void load(Ptr<ExpressionGraph> graph,
|
void load(Ptr<ExpressionGraph> graph,
|
||||||
const std::string& name,
|
const std::vector<io::Item>& items,
|
||||||
bool /*markedReloaded*/ = true) override {
|
bool /*markedReloaded*/ = true) override {
|
||||||
std::map<std::string, std::string> nameMap
|
std::map<std::string, std::string> nameMap
|
||||||
= {{"decoder_U", "decoder_cell1_U"},
|
= {{"decoder_U", "decoder_cell1_U"},
|
||||||
@ -89,9 +89,7 @@ public:
|
|||||||
if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
|
if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
|
||||||
nameMap["Wemb"] = "Wemb";
|
nameMap["Wemb"] = "Wemb";
|
||||||
|
|
||||||
LOG(info, "Loading model from {}", name);
|
auto ioItems = items;
|
||||||
// load items from .npz file
|
|
||||||
auto ioItems = io::loadItems(name);
|
|
||||||
// map names and remove a dummy matrices
|
// map names and remove a dummy matrices
|
||||||
for(auto it = ioItems.begin(); it != ioItems.end();) {
|
for(auto it = ioItems.begin(); it != ioItems.end();) {
|
||||||
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
|
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
|
||||||
@ -120,6 +118,14 @@ public:
|
|||||||
graph->load(ioItems);
|
graph->load(ioItems);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void load(Ptr<ExpressionGraph> graph,
|
||||||
|
const std::string& name,
|
||||||
|
bool /*markReloaded*/ = true) override {
|
||||||
|
LOG(info, "Loading model from {}", name);
|
||||||
|
auto ioItems = io::loadItems(name);
|
||||||
|
load(graph, ioItems);
|
||||||
|
}
|
||||||
|
|
||||||
void save(Ptr<ExpressionGraph> graph,
|
void save(Ptr<ExpressionGraph> graph,
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
bool saveTranslatorConfig = false) override {
|
bool saveTranslatorConfig = false) override {
|
||||||
|
@ -325,6 +325,12 @@ protected:
|
|||||||
public:
|
public:
|
||||||
Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {}
|
Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {}
|
||||||
|
|
||||||
|
virtual void load(Ptr<ExpressionGraph> graph,
|
||||||
|
const std::vector<io::Item>& items,
|
||||||
|
bool markedReloaded = true) override {
|
||||||
|
encdec_->load(graph, items, markedReloaded);
|
||||||
|
}
|
||||||
|
|
||||||
virtual void load(Ptr<ExpressionGraph> graph,
|
virtual void load(Ptr<ExpressionGraph> graph,
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
bool markedReloaded = true) override {
|
bool markedReloaded = true) override {
|
||||||
|
@ -144,6 +144,12 @@ std::string EncoderDecoder::getModelParametersAsString() {
|
|||||||
return std::string(out.c_str());
|
return std::string(out.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
|
||||||
|
const std::vector<io::Item>& items,
|
||||||
|
bool markedReloaded) {
|
||||||
|
graph->load(items, markedReloaded && !opt<bool>("ignore-model-config", false));
|
||||||
|
}
|
||||||
|
|
||||||
void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
|
void EncoderDecoder::load(Ptr<ExpressionGraph> graph,
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
bool markedReloaded) {
|
bool markedReloaded) {
|
||||||
|
@ -12,6 +12,11 @@ namespace marian {
|
|||||||
class IEncoderDecoder : public models::IModel {
|
class IEncoderDecoder : public models::IModel {
|
||||||
public:
|
public:
|
||||||
virtual ~IEncoderDecoder() {}
|
virtual ~IEncoderDecoder() {}
|
||||||
|
|
||||||
|
virtual void load(Ptr<ExpressionGraph> graph,
|
||||||
|
const std::vector<io::Item>& items,
|
||||||
|
bool markedReloaded = true) = 0;
|
||||||
|
|
||||||
virtual void load(Ptr<ExpressionGraph> graph,
|
virtual void load(Ptr<ExpressionGraph> graph,
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
bool markedReloaded = true) override
|
bool markedReloaded = true) override
|
||||||
@ -91,6 +96,10 @@ public:
|
|||||||
|
|
||||||
void push_back(Ptr<DecoderBase> decoder);
|
void push_back(Ptr<DecoderBase> decoder);
|
||||||
|
|
||||||
|
virtual void load(Ptr<ExpressionGraph> graph,
|
||||||
|
const std::vector<io::Item>& items,
|
||||||
|
bool markedReloaded = true) override;
|
||||||
|
|
||||||
virtual void load(Ptr<ExpressionGraph> graph,
|
virtual void load(Ptr<ExpressionGraph> graph,
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
bool markedReloaded = true) override;
|
bool markedReloaded = true) override;
|
||||||
|
@ -26,11 +26,9 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void load(Ptr<ExpressionGraph> graph,
|
void load(Ptr<ExpressionGraph> graph,
|
||||||
const std::string& name,
|
const std::vector<io::Item>& items,
|
||||||
bool /*markReloaded*/ = true) override {
|
bool /*markReloaded*/ = true) override {
|
||||||
LOG(info, "Loading model from {}", name);
|
auto ioItems = items;
|
||||||
// load items from .npz file
|
|
||||||
auto ioItems = io::loadItems(name);
|
|
||||||
// map names and remove a dummy matrix 'decoder_c_tt' from items to avoid creating isolated node
|
// map names and remove a dummy matrix 'decoder_c_tt' from items to avoid creating isolated node
|
||||||
for(auto it = ioItems.begin(); it != ioItems.end();) {
|
for(auto it = ioItems.begin(); it != ioItems.end();) {
|
||||||
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
|
// for backwards compatibility, turn one-dimensional vector into two dimensional matrix with first dimension being 1 and second dimension of the original size
|
||||||
@ -41,7 +39,7 @@ public:
|
|||||||
it->shape.set(0, 1);
|
it->shape.set(0, 1);
|
||||||
it->shape.set(1, dim);
|
it->shape.set(1, dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(it->name == "decoder_c_tt") {
|
if(it->name == "decoder_c_tt") {
|
||||||
it = ioItems.erase(it);
|
it = ioItems.erase(it);
|
||||||
} else if(it->name == "uidx") {
|
} else if(it->name == "uidx") {
|
||||||
@ -59,6 +57,14 @@ public:
|
|||||||
graph->load(ioItems);
|
graph->load(ioItems);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void load(Ptr<ExpressionGraph> graph,
|
||||||
|
const std::string& name,
|
||||||
|
bool /*markReloaded*/ = true) override {
|
||||||
|
LOG(info, "Loading model from {}", name);
|
||||||
|
auto ioItems = io::loadItems(name);
|
||||||
|
load(graph, ioItems);
|
||||||
|
}
|
||||||
|
|
||||||
void save(Ptr<ExpressionGraph> graph,
|
void save(Ptr<ExpressionGraph> graph,
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
bool saveTranslatorConfig = false) override {
|
bool saveTranslatorConfig = false) override {
|
||||||
|
@ -5,7 +5,7 @@ namespace marian {
|
|||||||
|
|
||||||
Ptr<Scorer> scorerByType(const std::string& fname,
|
Ptr<Scorer> scorerByType(const std::string& fname,
|
||||||
float weight,
|
float weight,
|
||||||
const std::string& model,
|
std::vector<io::Item> items,
|
||||||
Ptr<Options> options) {
|
Ptr<Options> options) {
|
||||||
options->set("inference", true);
|
options->set("inference", true);
|
||||||
std::string type = options->get<std::string>("type");
|
std::string type = options->get<std::string>("type");
|
||||||
@ -22,7 +22,7 @@ Ptr<Scorer> scorerByType(const std::string& fname,
|
|||||||
|
|
||||||
LOG(info, "Loading scorer of type {} as feature {}", type, fname);
|
LOG(info, "Loading scorer of type {} as feature {}", type, fname);
|
||||||
|
|
||||||
return New<ScorerWrapper>(encdec, fname, weight, model);
|
return New<ScorerWrapper>(encdec, fname, weight, items);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ptr<Scorer> scorerByType(const std::string& fname,
|
Ptr<Scorer> scorerByType(const std::string& fname,
|
||||||
@ -47,30 +47,30 @@ Ptr<Scorer> scorerByType(const std::string& fname,
|
|||||||
return New<ScorerWrapper>(encdec, fname, weight, ptr);
|
return New<ScorerWrapper>(encdec, fname, weight, ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
|
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models) {
|
||||||
std::vector<Ptr<Scorer>> scorers;
|
std::vector<Ptr<Scorer>> scorers;
|
||||||
|
|
||||||
auto models = options->get<std::vector<std::string>>("models");
|
|
||||||
|
|
||||||
std::vector<float> weights(models.size(), 1.f);
|
std::vector<float> weights(models.size(), 1.f);
|
||||||
if(options->hasAndNotEmpty("weights"))
|
if(options->hasAndNotEmpty("weights"))
|
||||||
weights = options->get<std::vector<float>>("weights");
|
weights = options->get<std::vector<float>>("weights");
|
||||||
|
|
||||||
bool isPrevRightLeft = false; // if the previous model was a right-to-left model
|
bool isPrevRightLeft = false; // if the previous model was a right-to-left model
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
for(auto model : models) {
|
for(auto items : models) {
|
||||||
std::string fname = "F" + std::to_string(i);
|
std::string fname = "F" + std::to_string(i);
|
||||||
|
|
||||||
// load options specific for the scorer
|
// load options specific for the scorer
|
||||||
auto modelOptions = New<Options>(options->clone());
|
auto modelOptions = New<Options>(options->clone());
|
||||||
try {
|
if(!options->get<bool>("ignore-model-config")) {
|
||||||
if(!options->get<bool>("ignore-model-config")) {
|
YAML::Node modelYaml;
|
||||||
YAML::Node modelYaml;
|
io::getYamlFromModel(modelYaml, "special:model.yml", items);
|
||||||
io::getYamlFromModel(modelYaml, "special:model.yml", model);
|
if(!modelYaml.IsNull()) {
|
||||||
|
LOG(info, "Loaded model config");
|
||||||
modelOptions->merge(modelYaml, true);
|
modelOptions->merge(modelYaml, true);
|
||||||
}
|
}
|
||||||
} catch(std::runtime_error&) {
|
else {
|
||||||
LOG(warn, "No model settings found in model file");
|
LOG(warn, "No model settings found in model file");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// l2r and r2l cannot be used in the same ensemble
|
// l2r and r2l cannot be used in the same ensemble
|
||||||
@ -85,13 +85,24 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
scorers.push_back(scorerByType(fname, weights[i], model, modelOptions));
|
scorers.push_back(scorerByType(fname, weights[i], items, modelOptions));
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
|
||||||
return scorers;
|
return scorers;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
|
||||||
|
std::vector<std::vector<io::Item>> model_items;
|
||||||
|
auto models = options->get<std::vector<std::string>>("models");
|
||||||
|
for(auto model : models) {
|
||||||
|
auto items = io::loadItems(model);
|
||||||
|
model_items.push_back(std::move(items));
|
||||||
|
}
|
||||||
|
|
||||||
|
return createScorers(options, model_items);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<const void*>& ptrs) {
|
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<const void*>& ptrs) {
|
||||||
std::vector<Ptr<Scorer>> scorers;
|
std::vector<Ptr<Scorer>> scorers;
|
||||||
|
|
||||||
@ -105,14 +116,16 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<c
|
|||||||
|
|
||||||
// load options specific for the scorer
|
// load options specific for the scorer
|
||||||
auto modelOptions = New<Options>(options->clone());
|
auto modelOptions = New<Options>(options->clone());
|
||||||
try {
|
if(!options->get<bool>("ignore-model-config")) {
|
||||||
if(!options->get<bool>("ignore-model-config")) {
|
YAML::Node modelYaml;
|
||||||
YAML::Node modelYaml;
|
io::getYamlFromModel(modelYaml, "special:model.yml", ptr);
|
||||||
io::getYamlFromModel(modelYaml, "special:model.yml", ptr);
|
if(!modelYaml.IsNull()) {
|
||||||
|
LOG(info, "Loaded model config");
|
||||||
modelOptions->merge(modelYaml, true);
|
modelOptions->merge(modelYaml, true);
|
||||||
}
|
}
|
||||||
} catch(std::runtime_error&) {
|
else {
|
||||||
LOG(warn, "No model settings found in model file");
|
LOG(warn, "No model settings found in model file");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
scorers.push_back(scorerByType(fname, weights[i], ptr, modelOptions));
|
scorers.push_back(scorerByType(fname, weights[i], ptr, modelOptions));
|
||||||
|
@ -73,9 +73,19 @@ class ScorerWrapper : public Scorer {
|
|||||||
private:
|
private:
|
||||||
Ptr<IEncoderDecoder> encdec_;
|
Ptr<IEncoderDecoder> encdec_;
|
||||||
std::string fname_;
|
std::string fname_;
|
||||||
|
std::vector<io::Item> items_;
|
||||||
const void* ptr_;
|
const void* ptr_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
ScorerWrapper(Ptr<models::IModel> encdec,
|
||||||
|
const std::string& name,
|
||||||
|
float weight,
|
||||||
|
std::vector<io::Item>& items)
|
||||||
|
: Scorer(name, weight),
|
||||||
|
encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
|
||||||
|
items_(items),
|
||||||
|
ptr_{0} {}
|
||||||
|
|
||||||
ScorerWrapper(Ptr<models::IModel> encdec,
|
ScorerWrapper(Ptr<models::IModel> encdec,
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
float weight,
|
float weight,
|
||||||
@ -97,7 +107,9 @@ public:
|
|||||||
|
|
||||||
virtual void init(Ptr<ExpressionGraph> graph) override {
|
virtual void init(Ptr<ExpressionGraph> graph) override {
|
||||||
graph->switchParams(getName());
|
graph->switchParams(getName());
|
||||||
if(ptr_)
|
if(!items_.empty())
|
||||||
|
encdec_->load(graph, items_);
|
||||||
|
else if(ptr_)
|
||||||
encdec_->mmap(graph, ptr_);
|
encdec_->mmap(graph, ptr_);
|
||||||
else
|
else
|
||||||
encdec_->load(graph, fname_);
|
encdec_->load(graph, fname_);
|
||||||
@ -142,12 +154,19 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Ptr<Scorer> scorerByType(const std::string& fname,
|
||||||
|
float weight,
|
||||||
|
std::vector<io::Item> items,
|
||||||
|
Ptr<Options> options);
|
||||||
|
|
||||||
Ptr<Scorer> scorerByType(const std::string& fname,
|
Ptr<Scorer> scorerByType(const std::string& fname,
|
||||||
float weight,
|
float weight,
|
||||||
const std::string& model,
|
const std::string& model,
|
||||||
Ptr<Options> config);
|
Ptr<Options> config);
|
||||||
|
|
||||||
|
|
||||||
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options);
|
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options);
|
||||||
|
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models);
|
||||||
|
|
||||||
Ptr<Scorer> scorerByType(const std::string& fname,
|
Ptr<Scorer> scorerByType(const std::string& fname,
|
||||||
float weight,
|
float weight,
|
||||||
|
@ -20,12 +20,7 @@
|
|||||||
#include "translator/scorers.h"
|
#include "translator/scorers.h"
|
||||||
|
|
||||||
// currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled.
|
// currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled.
|
||||||
// @TODO: add this as an actual feature.
|
|
||||||
#define MMAP 0
|
|
||||||
|
|
||||||
#if MMAP
|
|
||||||
#include "3rd_party/mio/mio.hpp"
|
#include "3rd_party/mio/mio.hpp"
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace marian {
|
namespace marian {
|
||||||
|
|
||||||
@ -42,9 +37,8 @@ private:
|
|||||||
|
|
||||||
size_t numDevices_;
|
size_t numDevices_;
|
||||||
|
|
||||||
#if MMAP
|
std::vector<mio::mmap_source> model_mmaps_; // map
|
||||||
std::vector<mio::mmap_source> mmaps_;
|
std::vector<std::vector<io::Item>> model_items_; // non-mmap
|
||||||
#endif
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Translate(Ptr<Options> options)
|
Translate(Ptr<Options> options)
|
||||||
@ -76,15 +70,21 @@ public:
|
|||||||
scorers_.resize(numDevices_);
|
scorers_.resize(numDevices_);
|
||||||
graphs_.resize(numDevices_);
|
graphs_.resize(numDevices_);
|
||||||
|
|
||||||
#if MMAP
|
|
||||||
auto models = options->get<std::vector<std::string>>("models");
|
auto models = options->get<std::vector<std::string>>("models");
|
||||||
for(auto model : models) {
|
if(options_->get<bool>("model-mmap", false)) {
|
||||||
marian::filesystem::Path modelPath(model);
|
for(auto model : models) {
|
||||||
ABORT_IF(modelPath.extension() != marian::filesystem::Path(".bin"),
|
ABORT_IF(!io::isBin(model), "Non-binarized models cannot be mmapped");
|
||||||
"Non-binarized models cannot be mmapped");
|
LOG(info, "Loading model from {}", model);
|
||||||
mmaps_.push_back(std::move(mio::mmap_source(model)));
|
model_mmaps_.push_back(mio::mmap_source(model));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(auto model : models) {
|
||||||
|
LOG(info, "Loading model from {}", model);
|
||||||
|
auto items = io::loadItems(model);
|
||||||
|
model_items_.push_back(std::move(items));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
size_t id = 0;
|
size_t id = 0;
|
||||||
for(auto device : devices) {
|
for(auto device : devices) {
|
||||||
@ -101,11 +101,14 @@ public:
|
|||||||
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
|
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
|
||||||
graphs_[id] = graph;
|
graphs_[id] = graph;
|
||||||
|
|
||||||
#if MMAP
|
std::vector<Ptr<Scorer>> scorers;
|
||||||
auto scorers = createScorers(options_, mmaps_);
|
if(options_->get<bool>("model-mmap", false)) {
|
||||||
#else
|
scorers = createScorers(options_, model_mmaps_);
|
||||||
auto scorers = createScorers(options_);
|
}
|
||||||
#endif
|
else {
|
||||||
|
scorers = createScorers(options_, model_items_);
|
||||||
|
}
|
||||||
|
|
||||||
for(auto scorer : scorers) {
|
for(auto scorer : scorers) {
|
||||||
scorer->init(graph);
|
scorer->init(graph);
|
||||||
if(shortlistGenerator_)
|
if(shortlistGenerator_)
|
||||||
@ -146,11 +149,11 @@ public:
|
|||||||
std::mutex syncCounts;
|
std::mutex syncCounts;
|
||||||
|
|
||||||
// timer and counters for total elapsed time and statistics
|
// timer and counters for total elapsed time and statistics
|
||||||
std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
|
std::unique_ptr<timer::Timer> totTimer(new timer::Timer());
|
||||||
size_t totBatches = 0;
|
size_t totBatches = 0;
|
||||||
size_t totLines = 0;
|
size_t totLines = 0;
|
||||||
size_t totSourceTokens = 0;
|
size_t totSourceTokens = 0;
|
||||||
|
|
||||||
// timer and counters for elapsed time and statistics between updates
|
// timer and counters for elapsed time and statistics between updates
|
||||||
std::unique_ptr<timer::Timer> curTimer(new timer::Timer());
|
std::unique_ptr<timer::Timer> curTimer(new timer::Timer());
|
||||||
size_t curBatches = 0;
|
size_t curBatches = 0;
|
||||||
@ -176,7 +179,7 @@ public:
|
|||||||
bg.prepare();
|
bg.prepare();
|
||||||
for(auto batch : bg) {
|
for(auto batch : bg) {
|
||||||
auto task = [=, &syncCounts,
|
auto task = [=, &syncCounts,
|
||||||
&totBatches, &totLines, &totSourceTokens, &totTimer,
|
&totBatches, &totLines, &totSourceTokens, &totTimer,
|
||||||
&curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) {
|
&curBatches, &curLines, &curSourceTokens, &curTimer](size_t id) {
|
||||||
thread_local Ptr<ExpressionGraph> graph;
|
thread_local Ptr<ExpressionGraph> graph;
|
||||||
thread_local std::vector<Ptr<Scorer>> scorers;
|
thread_local std::vector<Ptr<Scorer>> scorers;
|
||||||
@ -200,12 +203,12 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// if we asked for speed information display this
|
// if we asked for speed information display this
|
||||||
if(statFreq.n > 0) {
|
if(statFreq.n > 0) {
|
||||||
std::lock_guard<std::mutex> lock(syncCounts);
|
std::lock_guard<std::mutex> lock(syncCounts);
|
||||||
totBatches++;
|
totBatches++;
|
||||||
totLines += batch->size();
|
totLines += batch->size();
|
||||||
totSourceTokens += batch->front()->batchWords();
|
totSourceTokens += batch->front()->batchWords();
|
||||||
|
|
||||||
curBatches++;
|
curBatches++;
|
||||||
curLines += batch->size();
|
curLines += batch->size();
|
||||||
curSourceTokens += batch->front()->batchWords();
|
curSourceTokens += batch->front()->batchWords();
|
||||||
@ -214,10 +217,10 @@ public:
|
|||||||
double totTime = totTimer->elapsed();
|
double totTime = totTimer->elapsed();
|
||||||
double curTime = curTimer->elapsed();
|
double curTime = curTimer->elapsed();
|
||||||
|
|
||||||
LOG(info,
|
LOG(info,
|
||||||
"Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
|
"Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (since last): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
|
||||||
totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime);
|
totBatches, totLines, totSourceTokens, totTime, curBatches / curTime, curLines / curTime, curSourceTokens / curTime);
|
||||||
|
|
||||||
// reset stats between updates
|
// reset stats between updates
|
||||||
curBatches = curLines = curSourceTokens = 0;
|
curBatches = curLines = curSourceTokens = 0;
|
||||||
curTimer.reset(new timer::Timer());
|
curTimer.reset(new timer::Timer());
|
||||||
@ -230,12 +233,12 @@ public:
|
|||||||
|
|
||||||
// make sure threads are joined before other local variables get de-allocated
|
// make sure threads are joined before other local variables get de-allocated
|
||||||
threadPool.join_all();
|
threadPool.join_all();
|
||||||
|
|
||||||
// display final speed numbers over total translation if intermediate displays were requested
|
// display final speed numbers over total translation if intermediate displays were requested
|
||||||
if(statFreq.n > 0) {
|
if(statFreq.n > 0) {
|
||||||
double totTime = totTimer->elapsed();
|
double totTime = totTimer->elapsed();
|
||||||
LOG(info,
|
LOG(info,
|
||||||
"Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
|
"Processed {} batches, {} lines, {} source tokens in {:.2f}s - Speed (total): {:.2f} batches/s - {:.2f} lines/s - {:.2f} tokens/s",
|
||||||
totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime);
|
totBatches, totLines, totSourceTokens, totTime, totBatches / totTime, totLines / totTime, totSourceTokens / totTime);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -288,6 +291,14 @@ public:
|
|||||||
auto devices = Config::getDevices(options_);
|
auto devices = Config::getDevices(options_);
|
||||||
numDevices_ = devices.size();
|
numDevices_ = devices.size();
|
||||||
|
|
||||||
|
// preload models
|
||||||
|
std::vector<std::vector<io::Item>> model_items_;
|
||||||
|
auto models = options->get<std::vector<std::string>>("models");
|
||||||
|
for(auto model : models) {
|
||||||
|
auto items = io::loadItems(model);
|
||||||
|
model_items_.push_back(std::move(items));
|
||||||
|
}
|
||||||
|
|
||||||
// initialize scorers
|
// initialize scorers
|
||||||
for(auto device : devices) {
|
for(auto device : devices) {
|
||||||
auto graph = New<ExpressionGraph>(true);
|
auto graph = New<ExpressionGraph>(true);
|
||||||
@ -303,7 +314,7 @@ public:
|
|||||||
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
|
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
|
||||||
graphs_.push_back(graph);
|
graphs_.push_back(graph);
|
||||||
|
|
||||||
auto scorers = createScorers(options_);
|
auto scorers = createScorers(options_, model_items_);
|
||||||
for(auto scorer : scorers) {
|
for(auto scorer : scorers) {
|
||||||
scorer->init(graph);
|
scorer->init(graph);
|
||||||
if(shortlistGenerator_)
|
if(shortlistGenerator_)
|
||||||
|
Loading…
Reference in New Issue
Block a user