diff --git a/3rd_party/marian-dev b/3rd_party/marian-dev index 94aeaa4..ca15d61 160000 --- a/3rd_party/marian-dev +++ b/3rd_party/marian-dev @@ -1 +1 @@ -Subproject commit 94aeaa4616a0fb01ac95a23f0e74a214a94e7609 +Subproject commit ca15d61c87ef2f8f2c290b75a5da6236eb9833d2 diff --git a/app/service-cli.cpp b/app/service-cli.cpp index e9dc433..0e958d6 100644 --- a/app/service-cli.cpp +++ b/app/service-cli.cpp @@ -18,15 +18,17 @@ int main(int argc, char *argv[]) { // Prepare memories for model and shortlist marian::bergamot::AlignedMemory modelBytes, shortlistBytes; + std::vector> vocabsBytes; if (options->get("check-bytearray")) { // Load legit values into bytearrays. modelBytes = marian::bergamot::getModelMemoryFromConfig(options); shortlistBytes = marian::bergamot::getShortlistMemoryFromConfig(options); + marian::bergamot::getVocabsMemoryFromConfig(options, vocabsBytes); } marian::bergamot::Service service(options, std::move(modelBytes), - std::move(shortlistBytes)); + std::move(shortlistBytes), std::move(vocabsBytes)); // Read a large input text blob from stdin std::ostringstream std_input; diff --git a/src/translator/batch_translator.cpp b/src/translator/batch_translator.cpp index 6b2425d..c627172 100644 --- a/src/translator/batch_translator.cpp +++ b/src/translator/batch_translator.cpp @@ -4,6 +4,7 @@ #include "data/corpus.h" #include "data/text_input.h" #include "translator/beam_search.h" +#include "byte_array_util.h" namespace marian { namespace bergamot { @@ -18,11 +19,11 @@ BatchTranslator::BatchTranslator(DeviceId const device, void BatchTranslator::initialize() { // Initializes the graph. + bool check = options_->get("check-bytearray",false); // Flag holds whether validate the bytearray (model and shortlist) if (options_->hasAndNotEmpty("shortlist")) { int srcIdx = 0, trgIdx = 1; bool shared_vcb = vocabs_->front() == vocabs_->back(); if (shortlistMemory_->size() > 0 && shortlistMemory_->begin() != nullptr) { - bool check = options_->get("check-bytearray",true); slgen_ = New(shortlistMemory_->begin(), shortlistMemory_->size(), vocabs_->front(), vocabs_->back(), srcIdx, trgIdx, shared_vcb, check); @@ -45,6 +46,10 @@ void BatchTranslator::initialize() { if (modelMemory_->size() > 0 && modelMemory_->begin() != nullptr) { // If we have provided a byte array that contains the model memory, we can initialise the model from there, as opposed to from reading in the config file ABORT_IF((uintptr_t)modelMemory_->begin() % 256 != 0, "The provided memory is not aligned to 256 bytes and will crash when vector instructions are used on it."); + if (check) { + ABORT_IF(!validateBinaryModel(*modelMemory_, modelMemory_->size()), + "The binary file is invalid. Incomplete or corrupted download?"); + } const std::vector container = {modelMemory_->begin()}; // Marian supports multiple models initialised in this manner hence std::vector. However we will only ever use 1 during decoding. scorers_ = createScorers(options_, container); } else { diff --git a/src/translator/byte_array_util.cpp b/src/translator/byte_array_util.cpp index c3bf7cc..00beaa6 100644 --- a/src/translator/byte_array_util.cpp +++ b/src/translator/byte_array_util.cpp @@ -1,12 +1,12 @@ #include "byte_array_util.h" #include #include +#include namespace marian { namespace bergamot { namespace { - // This is a basic validator that checks if the file has not been truncated // it basically loads up the header and checks @@ -26,9 +26,10 @@ const T* get(const void*& current, uint64_t num = 1) { current = (const T*)current + num; return ptr; } +} // Anonymous namespace -bool validateBinaryModel(AlignedMemory& model, uint64_t fileSize) { - const void * current = &model[0]; +bool validateBinaryModel(const AlignedMemory& model, uint64_t fileSize) { + const void * current = model.begin(); uint64_t memoryNeeded = sizeof(uint64_t)*2; // We keep track of how much memory we would need if we have a complete file uint64_t numHeaders; if (fileSize >= memoryNeeded) { // We have enough filesize to fetch the headers. @@ -76,8 +77,6 @@ bool validateBinaryModel(AlignedMemory& model, uint64_t fileSize) { } } -} // Anonymous namespace - AlignedMemory loadFileToMemory(const std::string& path, size_t alignment){ uint64_t fileSize = filesystem::fileSize(path); io::InputFileStream in(path); @@ -89,13 +88,12 @@ AlignedMemory loadFileToMemory(const std::string& path, size_t alignment){ } AlignedMemory getModelMemoryFromConfig(marian::Ptr options){ - auto models = options->get>("models"); - ABORT_IF(models.size() != 1, "Loading multiple binary models is not supported for now as it is not necessary."); - marian::filesystem::Path modelPath(models[0]); - ABORT_IF(modelPath.extension() != marian::filesystem::Path(".bin"), "The file of binary model should end with .bin"); - AlignedMemory alignedMemory = loadFileToMemory(models[0], 256); - ABORT_IF(!validateBinaryModel(alignedMemory, alignedMemory.size()), "The binary file is invalid. Incomplete or corrupted download?"); - return alignedMemory; + auto models = options->get>("models"); + ABORT_IF(models.size() != 1, "Loading multiple binary models is not supported for now as it is not necessary."); + marian::filesystem::Path modelPath(models[0]); + ABORT_IF(modelPath.extension() != marian::filesystem::Path(".bin"), "The file of binary model should end with .bin"); + AlignedMemory alignedMemory = loadFileToMemory(models[0], 256); + return alignedMemory; } AlignedMemory getShortlistMemoryFromConfig(marian::Ptr options){ @@ -104,5 +102,20 @@ AlignedMemory getShortlistMemoryFromConfig(marian::Ptr options) return loadFileToMemory(shortlist[0], 64); } +void getVocabsMemoryFromConfig(marian::Ptr options, + std::vector>& vocabMemories){ + auto vfiles = options->get>("vocabs"); + ABORT_IF(vfiles.size() < 2, "Insufficient number of vocabularies."); + vocabMemories.resize(vfiles.size()); + std::unordered_map> vocabMap; + for (size_t i = 0; i < vfiles.size(); ++i) { + auto m = vocabMap.emplace(std::make_pair(vfiles[i], std::shared_ptr())); + if (m.second) { + m.first->second = std::make_shared(loadFileToMemory(vfiles[i], 64)); + } + vocabMemories[i] = m.first->second; + } +} + } // namespace bergamot } // namespace marian diff --git a/src/translator/byte_array_util.h b/src/translator/byte_array_util.h index a8df1cb..3cbf3d3 100644 --- a/src/translator/byte_array_util.h +++ b/src/translator/byte_array_util.h @@ -7,6 +7,8 @@ namespace bergamot { AlignedMemory loadFileToMemory(const std::string& path, size_t alignment); AlignedMemory getModelMemoryFromConfig(marian::Ptr options); AlignedMemory getShortlistMemoryFromConfig(marian::Ptr options); - +void getVocabsMemoryFromConfig(marian::Ptr options, + std::vector>& vocabMemories); +bool validateBinaryModel(const AlignedMemory& model, uint64_t fileSize); } // namespace bergamot } // namespace marian diff --git a/src/translator/service.cpp b/src/translator/service.cpp index 3d19f5e..385a2a5 100644 --- a/src/translator/service.cpp +++ b/src/translator/service.cpp @@ -6,21 +6,34 @@ #include inline std::vector> -loadVocabularies(marian::Ptr options) { +loadVocabularies(marian::Ptr options, + std::vector>&& vocabMemories) { // @TODO: parallelize vocab loading for faster startup - auto vfiles = options->get>("vocabs"); - // with the current setup, we need at least two vocabs: src and trg - ABORT_IF(vfiles.size() < 2, "Insufficient number of vocabularies."); - std::vector> vocabs(vfiles.size()); - std::unordered_map> vmap; - for (size_t i = 0; i < vocabs.size(); ++i) { - auto m = - vmap.emplace(std::make_pair(vfiles[i], marian::Ptr())); - if (m.second) { // new: load the vocab - m.first->second = marian::New(options, i); - m.first->second->load(vfiles[i]); + std::vector> vocabs; + if(!vocabMemories.empty()){ + // load vocabs from buffer + ABORT_IF(vocabMemories.size() < 2, "Insufficient number of vocabularies."); + vocabs.resize(vocabMemories.size()); + for (size_t i = 0; i < vocabs.size(); i++) { + marian::Ptr vocab = marian::New(options, i); + vocab->loadFromSerialized(absl::string_view(vocabMemories[i]->begin(), vocabMemories[i]->size())); + vocabs[i] = vocab; + } + } else { + // load vocabs from file + auto vfiles = options->get>("vocabs"); + // with the current setup, we need at least two vocabs: src and trg + ABORT_IF(vfiles.size() < 2, "Insufficient number of vocabularies."); + vocabs.resize(vfiles.size()); + std::unordered_map> vmap; + for (size_t i = 0; i < vocabs.size(); ++i) { + auto m = vmap.emplace(std::make_pair(vfiles[i], marian::Ptr())); + if (m.second) { // new: load the vocab + m.first->second = marian::New(options, i); + m.first->second->load(vfiles[i]); + } + vocabs[i] = m.first->second; } - vocabs[i] = m.first->second; } return vocabs; } @@ -28,11 +41,14 @@ loadVocabularies(marian::Ptr options) { namespace marian { namespace bergamot { -Service::Service(Ptr options, AlignedMemory modelMemory, AlignedMemory shortlistMemory) - : requestId_(0), options_(options), vocabs_(std::move(loadVocabularies(options))), +Service::Service(Ptr options, AlignedMemory modelMemory, AlignedMemory shortlistMemory, + std::vector> vocabMemories) + : requestId_(0), options_(options), + vocabs_(std::move(loadVocabularies(options, std::move(vocabMemories)))), text_processor_(vocabs_, options), batcher_(options), numWorkers_(options->get("cpu-threads")), - modelMemory_(std::move(modelMemory)), shortlistMemory_(std::move(shortlistMemory)) + modelMemory_(std::move(modelMemory)), + shortlistMemory_(std::move(shortlistMemory)) #ifndef WASM_COMPATIBLE_SOURCE // 0 elements in PCQueue is illegal and can lead to failures. Adding a // guard to have at least one entry allocated. In the single-threaded diff --git a/src/translator/service.h b/src/translator/service.h index 288c649..721d436 100644 --- a/src/translator/service.h +++ b/src/translator/service.h @@ -64,10 +64,12 @@ class Service { public: /// @param options Marian options object /// @param modelMemory byte array (aligned to 256!!!) that contains the bytes - /// of a model.bin. Optional, defaults to nullptr when not used + /// of a model.bin. /// @param shortlistMemory byte array of shortlist (aligned to 64) + /// @param vocabMemories vector of vocabulary memories (aligned to 64) explicit Service(Ptr options, AlignedMemory modelMemory, - AlignedMemory shortlistMemory); + AlignedMemory shortlistMemory, + std::vector> vocabMemories); /// Construct Service purely from Options. This expects options which /// marian-decoder expects to be set for loading model shortlist and @@ -76,24 +78,30 @@ public: /// /// This is equivalent to a call to: /// ```cpp - /// Service(options, AlignedMemory(), AlignedMemory()) + /// Service(options, AlignedMemory(), AlignedMemory(), {}) /// ``` /// wherein empty memory is passed and internal flow defaults to file-based - /// model, shortlist loading. + /// model, shortlist loading. AlignedMemory() corresponds to empty memory explicit Service(Ptr options) - : Service(options, AlignedMemory(), AlignedMemory()) {} + : Service(options, AlignedMemory(), AlignedMemory(), {}) {} /// Construct Service from a string configuration. /// @param [in] config string parsable as YAML expected to adhere with marian /// config - /// @param [in] model_memory byte array (aligned to 256!!!) that contains the - /// bytes of a model.bin. Optional. - /// @param [in] shortlistMemory byte array of shortlist (aligned to 64) + /// @param [in] modelMemory byte array (aligned to 256!!!) that contains the + /// bytes of a model.bin. Optional. AlignedMemory() corresponds to empty memory + /// @param [in] shortlistMemory byte array of shortlist (aligned to 64). Optional. + /// @param [in] vocabMemories vector of vocabulary memories (aligned to 64). Optional. + /// If two vocabularies are the same (based on the filenames), two entries (shared + /// pointers) will be generated which share the same AlignedMemory object. explicit Service(const std::string &config, AlignedMemory modelMemory = AlignedMemory(), - AlignedMemory shortlistMemory = AlignedMemory()) + AlignedMemory shortlistMemory = AlignedMemory(), + std::vector> vocabsMemories = {}) : Service(parseOptions(config, /*validate=*/false), - std::move(modelMemory), std::move(shortlistMemory)) {} + std::move(modelMemory), + std::move(shortlistMemory), + std::move(vocabsMemories)) {} /// Explicit destructor to clean up after any threads initialized in /// asynchronous operation mode. @@ -187,7 +195,6 @@ private: /// ordering among requests and logging/book-keeping. size_t requestId_; - /// Store vocabs representing source and target. std::vector> vocabs_; // ORDER DEPENDENCY (text_processor_)