diff --git a/src/translator/batch_translator.cpp b/src/translator/batch_translator.cpp index c627172..b35c4ce 100644 --- a/src/translator/batch_translator.cpp +++ b/src/translator/batch_translator.cpp @@ -10,11 +10,11 @@ namespace marian { namespace bergamot { BatchTranslator::BatchTranslator(DeviceId const device, - std::vector> &vocabs, + Vocabs &vocabs, Ptr options, const AlignedMemory* modelMemory, const AlignedMemory* shortlistMemory) - : device_(device), options_(options), vocabs_(&vocabs), + : device_(device), options_(options), vocabs_(vocabs), modelMemory_(modelMemory), shortlistMemory_(shortlistMemory) {} void BatchTranslator::initialize() { @@ -22,17 +22,17 @@ void BatchTranslator::initialize() { bool check = options_->get("check-bytearray",false); // Flag holds whether validate the bytearray (model and shortlist) if (options_->hasAndNotEmpty("shortlist")) { int srcIdx = 0, trgIdx = 1; - bool shared_vcb = vocabs_->front() == vocabs_->back(); + bool shared_vcb = vocabs_.sources().front() == vocabs_.target(); // vocabs_->sources().front() is invoked as we currently only support one source vocab if (shortlistMemory_->size() > 0 && shortlistMemory_->begin() != nullptr) { slgen_ = New(shortlistMemory_->begin(), shortlistMemory_->size(), - vocabs_->front(), vocabs_->back(), - srcIdx, trgIdx, shared_vcb, check); + vocabs_.sources().front(), vocabs_.target(), + srcIdx, trgIdx, shared_vcb, check); } else { // Changed to BinaryShortlistGenerator to enable loading binary shortlist file // This class also supports text shortlist file - slgen_ = New(options_, vocabs_->front(), - vocabs_->back(), srcIdx, + slgen_ = New(options_, vocabs_.sources().front(), + vocabs_.target(), srcIdx, trgIdx, shared_vcb); } } @@ -97,7 +97,7 @@ void BatchTranslator::translate(Batch &batch) { std::vector> subBatches; for (size_t j = 0; j < maxDims.size(); ++j) { subBatches.emplace_back( - New(batchSize, maxDims[j], vocabs_->at(j))); + New(batchSize, maxDims[j], vocabs_.sources().at(j))); 
} std::vector words(maxDims.size(), 0); @@ -116,9 +116,8 @@ void BatchTranslator::translate(Batch &batch) { auto corpus_batch = Ptr(new CorpusBatch(subBatches)); corpus_batch->setSentenceIds(sentenceIds); - - auto trgVocab = vocabs_->back(); - auto search = New(options_, scorers_, trgVocab); + + auto search = New(options_, scorers_, vocabs_.target()); auto histories = std::move(search->search(graph_, corpus_batch)); batch.completeBatch(histories); diff --git a/src/translator/batch_translator.h b/src/translator/batch_translator.h index 761a534..048ba77 100644 --- a/src/translator/batch_translator.h +++ b/src/translator/batch_translator.h @@ -11,6 +11,7 @@ #include "request.h" #include "translator/history.h" #include "translator/scorers.h" +#include "vocabs.h" #ifndef WASM_COMPATIBLE_SOURCE #include "pcqueue.h" @@ -34,7 +35,7 @@ public: * @param modelMemory byte array (aligned to 256!!!) that contains the bytes of a model.bin. Provide a nullptr if not used. * @param shortlistMemory byte array of shortlist (aligned to 64) */ - explicit BatchTranslator(DeviceId const device, std::vector> &vocabs, + explicit BatchTranslator(DeviceId const device, Vocabs &vocabs, Ptr options, const AlignedMemory* modelMemory, const AlignedMemory* shortlistMemory); // convenience function for logging. 
TODO(jerin) @@ -45,7 +46,7 @@ public: private: Ptr options_; DeviceId device_; - std::vector> *vocabs_; + const Vocabs& vocabs_; Ptr graph_; std::vector> scorers_; Ptr slgen_; diff --git a/src/translator/definitions.h b/src/translator/definitions.h index 175397d..bf1cb57 100644 --- a/src/translator/definitions.h +++ b/src/translator/definitions.h @@ -28,27 +28,6 @@ struct MemoryBundle { /// @todo Not implemented yet AlignedMemory ssplitPrefixFile; - - MemoryBundle() = default; - - MemoryBundle(MemoryBundle &&from){ - model = std::move(from.model); - shortlist = std::move(from.shortlist); - vocabs = std::move(vocabs); - ssplitPrefixFile = std::move(from.ssplitPrefixFile); - } - - MemoryBundle &operator=(MemoryBundle &&from) { - model = std::move(from.model); - shortlist = std::move(from.shortlist); - vocabs = std::move(vocabs); - ssplitPrefixFile = std::move(from.ssplitPrefixFile); - return *this; - } - - // Delete copy constructors - MemoryBundle(const MemoryBundle&) = delete; - MemoryBundle& operator=(const MemoryBundle&) = delete; }; } // namespace bergamot diff --git a/src/translator/response_builder.cpp b/src/translator/response_builder.cpp index f68bd31..037d456 100644 --- a/src/translator/response_builder.cpp +++ b/src/translator/response_builder.cpp @@ -65,11 +65,10 @@ void ResponseBuilder::buildTranslatedText(Histories &histories, Result result = onebest[0]; // Expecting only one result; Words words = std::get<0>(result); - auto targetVocab = vocabs_->back(); std::string decoded; std::vector targetSentenceMappings; - targetVocab->decodeWithByteRanges(words, decoded, targetSentenceMappings); + vocabs_.target()->decodeWithByteRanges(words, decoded, targetSentenceMappings); switch (responseOptions_.concatStrategy) { case ConcatStrategy::FAITHFUL: { diff --git a/src/translator/response_builder.h b/src/translator/response_builder.h index 85caffb..b8a8dd4 100644 --- a/src/translator/response_builder.h +++ b/src/translator/response_builder.h @@ -4,6 +4,7 @@ 
#include "data/types.h" #include "response.h" #include "response_options.h" +#include "vocabs.h" // For now we will work with this, to avoid complaints another structure is hard // to operate with. @@ -24,10 +25,10 @@ public: /// @param [in] vocabs: marian vocab object (used in decoding) /// @param [in] promise: promise to set with the constructed Response. ResponseBuilder(ResponseOptions responseOptions, AnnotatedText &&source, - std::vector> &vocabs, + Vocabs &vocabs, std::promise &&promise) : responseOptions_(responseOptions), source_(std::move(source)), - vocabs_(&vocabs), promise_(std::move(promise)) {} + vocabs_(vocabs), promise_(std::move(promise)) {} /// Constructs and sets the promise of a Response object from obtained /// histories after translating. @@ -81,7 +82,7 @@ private: // Data members are context/curried args for the functor. ResponseOptions responseOptions_; - std::vector> *vocabs_; // vocabs are required for decoding + const Vocabs& vocabs_; // vocabs are required for decoding // and any source validation checks. std::promise promise_; // To be set when callback triggered and // after Response constructed. 
diff --git a/src/translator/service.cpp b/src/translator/service.cpp index 16c4743..5439667 100644 --- a/src/translator/service.cpp +++ b/src/translator/service.cpp @@ -5,45 +5,12 @@ #include #include -inline std::vector> -loadVocabularies(marian::Ptr options, - std::vector>&& vocabMemories) { - // @TODO: parallelize vocab loading for faster startup - std::vector> vocabs; - if(!vocabMemories.empty()){ - // load vocabs from buffer - ABORT_IF(vocabMemories.size() < 2, "Insufficient number of vocabularies."); - vocabs.resize(vocabMemories.size()); - for (size_t i = 0; i < vocabs.size(); i++) { - marian::Ptr vocab = marian::New(options, i); - vocab->loadFromSerialized(absl::string_view(vocabMemories[i]->begin(), vocabMemories[i]->size())); - vocabs[i] = vocab; - } - } else { - // load vocabs from file - auto vfiles = options->get>("vocabs"); - // with the current setup, we need at least two vocabs: src and trg - ABORT_IF(vfiles.size() < 2, "Insufficient number of vocabularies."); - vocabs.resize(vfiles.size()); - std::unordered_map> vmap; - for (size_t i = 0; i < vocabs.size(); ++i) { - auto m = vmap.emplace(std::make_pair(vfiles[i], marian::Ptr())); - if (m.second) { // new: load the vocab - m.first->second = marian::New(options, i); - m.first->second->load(vfiles[i]); - } - vocabs[i] = m.first->second; - } - } - return vocabs; -} - namespace marian { namespace bergamot { Service::Service(Ptr options, MemoryBundle memoryBundle) : requestId_(0), options_(options), - vocabs_(std::move(loadVocabularies(options, std::move(memoryBundle.vocabs)))), + vocabs_(options, std::move(memoryBundle.vocabs)), text_processor_(vocabs_, options), batcher_(options), numWorkers_(options->get("cpu-threads")), modelMemory_(std::move(memoryBundle.model)), diff --git a/src/translator/service.h b/src/translator/service.h index 9d0a67d..6c60888 100644 --- a/src/translator/service.h +++ b/src/translator/service.h @@ -9,6 +9,7 @@ #include "response_builder.h" #include "text_processor.h" #include 
"translator/parser.h" +#include "vocabs.h" #ifndef WASM_COMPATIBLE_SOURCE #include "pcqueue.h" @@ -172,7 +173,7 @@ private: size_t requestId_; /// Store vocabs representing source and target. - std::vector> vocabs_; // ORDER DEPENDENCY (text_processor_) + Vocabs vocabs_; // ORDER DEPENDENCY (text_processor_) /// TextProcesser takes a blob of text and converts into format consumable by /// the batch-translator and annotates sentences and words. diff --git a/src/translator/text_processor.cpp b/src/translator/text_processor.cpp index fb66901..457e2b9 100644 --- a/src/translator/text_processor.cpp +++ b/src/translator/text_processor.cpp @@ -4,7 +4,6 @@ #include "annotation.h" #include "common/options.h" -#include "data/vocab.h" #include namespace marian { @@ -12,13 +11,14 @@ namespace bergamot { Segment TextProcessor::tokenize(const string_view &segment, std::vector &wordRanges) { - return vocabs_->front()->encodeWithByteRanges( + // vocabs_->sources().front() is invoked as we currently only support one source vocab + return vocabs_.sources().front()->encodeWithByteRanges( segment, wordRanges, /*addEOS=*/false, /*inference=*/true); } -TextProcessor::TextProcessor(std::vector> &vocabs, +TextProcessor::TextProcessor(Vocabs &vocabs, Ptr options) - : vocabs_(&vocabs), sentence_splitter_(options) { + : vocabs_(vocabs), sentence_splitter_(options) { max_length_break_ = options->get("max-length-break"); max_length_break_ = max_length_break_ - 1; diff --git a/src/translator/text_processor.h b/src/translator/text_processor.h index 698e36e..f5d4d88 100644 --- a/src/translator/text_processor.h +++ b/src/translator/text_processor.h @@ -7,6 +7,7 @@ #include "annotation.h" #include "sentence_splitter.h" +#include "vocabs.h" #include @@ -21,7 +22,7 @@ class TextProcessor { // sentences (vector of words). In addition, the ByteRanges of the // source-tokens in unnormalized text are provided as string_views. 
public: - explicit TextProcessor(std::vector> &vocabs, Ptr); + explicit TextProcessor(Vocabs &vocabs, Ptr); void process(AnnotatedText &source, Segments &segments); @@ -36,9 +37,10 @@ private: Segments &segments, AnnotatedText &source); // shorthand, used only in truncate() - const Word sourceEosId() const { return vocabs_->front()->getEosId(); } + // vocabs_->sources().front() is invoked as we currently only support one source vocab + const Word sourceEosId() const { return vocabs_.sources().front()->getEosId(); } - std::vector> *vocabs_; + const Vocabs& vocabs_; SentenceSplitter sentence_splitter_; size_t max_length_break_; }; diff --git a/src/translator/vocabs.h b/src/translator/vocabs.h new file mode 100644 index 0000000..89aed4b --- /dev/null +++ b/src/translator/vocabs.h @@ -0,0 +1,81 @@ +#pragma once + +namespace marian { +namespace bergamot { + +/// Wrapper of Marian Vocab objects needed for translator. +/// Holds multiple source vocabularies and one target vocabulary +class Vocabs { +public: + /// Construct vocabs object from either byte-arrays or files + Vocabs(Ptr options, std::vector>&& vocabMemories): options_(options){ + if (!vocabMemories.empty()){ + // load vocabs from buffer + load(std::move(vocabMemories)); + } + else{ + // load vocabs from file + auto vocabPaths = options->get>("vocabs"); + load(vocabPaths); + } + } + + /// Get all source vocabularies (as a vector) + const std::vector>& sources() const { + return srcVocabs_; + } + + /// Get the target vocabulary + const Ptr& target() const { + return trgVocab_; + } + +private: + std::vector> srcVocabs_; // source vocabularies + Ptr trgVocab_; // target vocabulary + Ptr options_; + + // load from buffer + void load(std::vector>&& vocabMemories) { + // At least two vocabs: src and trg + ABORT_IF(vocabMemories.size() < 2, "Insufficient number of vocabularies."); + srcVocabs_.resize(vocabMemories.size()); + // hashMap is introduced to avoid double loading the same vocab + // loading vocabs (either 
from buffers or files) is the biggest speed bottleneck + // uintptr_t holds unique keys (addresses) for shared_ptr + std::unordered_map> vmap; + for (size_t i = 0; i < srcVocabs_.size(); i++) { + auto m = vmap.emplace(std::make_pair(reinterpret_cast(vocabMemories[i].get()), Ptr())); + if (m.second) { // new: load the vocab + m.first->second = New(options_, i); + m.first->second->loadFromSerialized(absl::string_view(vocabMemories[i]->begin(), vocabMemories[i]->size())); + } + srcVocabs_[i] = m.first->second; + } + // Initialize target vocab + trgVocab_ = srcVocabs_.back(); + srcVocabs_.pop_back(); + } + + // load from file + void load(const std::vector& vocabPaths){ + // with the current setup, we need at least two vocabs: src and trg + ABORT_IF(vocabPaths.size() < 2, "Insufficient number of vocabularies."); + srcVocabs_.resize(vocabPaths.size()); + std::unordered_map> vmap; + for (size_t i = 0; i < srcVocabs_.size(); ++i) { + auto m = vmap.emplace(std::make_pair(vocabPaths[i], Ptr())); + if (m.second) { // new: load the vocab + m.first->second = New(options_, i); + m.first->second->load(vocabPaths[i]); + } + srcVocabs_[i] = m.first->second; + } + // Initialize target vocab + trgVocab_ = srcVocabs_.back(); + srcVocabs_.pop_back(); + } +}; + +} // namespace bergamot +} // namespace marian