Refactor vocabs in Service (#143)

Co-authored-by: Nikolay Bogoychev <nheart@gmail.com>
This commit is contained in:
Qianqian Zhu 2021-05-17 13:09:03 +01:00 committed by GitHub
parent 77424a3df1
commit 5bd1fc6b83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 111 additions and 81 deletions

View File

@ -10,11 +10,11 @@ namespace marian {
namespace bergamot {
BatchTranslator::BatchTranslator(DeviceId const device,
std::vector<Ptr<Vocab const>> &vocabs,
Vocabs &vocabs,
Ptr<Options> options,
const AlignedMemory* modelMemory,
const AlignedMemory* shortlistMemory)
: device_(device), options_(options), vocabs_(&vocabs),
: device_(device), options_(options), vocabs_(vocabs),
modelMemory_(modelMemory), shortlistMemory_(shortlistMemory) {}
void BatchTranslator::initialize() {
@ -22,17 +22,17 @@ void BatchTranslator::initialize() {
bool check = options_->get<bool>("check-bytearray",false); // Flag holds whether validate the bytearray (model and shortlist)
if (options_->hasAndNotEmpty("shortlist")) {
int srcIdx = 0, trgIdx = 1;
bool shared_vcb = vocabs_->front() == vocabs_->back();
bool shared_vcb = vocabs_.sources().front() == vocabs_.target(); // vocabs_->sources().front() is invoked as we currently only support one source vocab
if (shortlistMemory_->size() > 0 && shortlistMemory_->begin() != nullptr) {
slgen_ = New<data::BinaryShortlistGenerator>(shortlistMemory_->begin(), shortlistMemory_->size(),
vocabs_->front(), vocabs_->back(),
srcIdx, trgIdx, shared_vcb, check);
vocabs_.sources().front(), vocabs_.target(),
srcIdx, trgIdx, shared_vcb, check);
}
else {
// Changed to BinaryShortlistGenerator to enable loading binary shortlist file
// This class also supports text shortlist file
slgen_ = New<data::BinaryShortlistGenerator>(options_, vocabs_->front(),
vocabs_->back(), srcIdx,
slgen_ = New<data::BinaryShortlistGenerator>(options_, vocabs_.sources().front(),
vocabs_.target(), srcIdx,
trgIdx, shared_vcb);
}
}
@ -97,7 +97,7 @@ void BatchTranslator::translate(Batch &batch) {
std::vector<Ptr<SubBatch>> subBatches;
for (size_t j = 0; j < maxDims.size(); ++j) {
subBatches.emplace_back(
New<SubBatch>(batchSize, maxDims[j], vocabs_->at(j)));
New<SubBatch>(batchSize, maxDims[j], vocabs_.sources().at(j)));
}
std::vector<size_t> words(maxDims.size(), 0);
@ -116,9 +116,8 @@ void BatchTranslator::translate(Batch &batch) {
auto corpus_batch = Ptr<CorpusBatch>(new CorpusBatch(subBatches));
corpus_batch->setSentenceIds(sentenceIds);
auto trgVocab = vocabs_->back();
auto search = New<BeamSearch>(options_, scorers_, trgVocab);
auto search = New<BeamSearch>(options_, scorers_, vocabs_.target());
auto histories = std::move(search->search(graph_, corpus_batch));
batch.completeBatch(histories);

View File

@ -11,6 +11,7 @@
#include "request.h"
#include "translator/history.h"
#include "translator/scorers.h"
#include "vocabs.h"
#ifndef WASM_COMPATIBLE_SOURCE
#include "pcqueue.h"
@ -34,7 +35,7 @@ public:
* @param modelMemory byte array (aligned to 256!!!) that contains the bytes of a model.bin. Provide a nullptr if not used.
* @param shortlistMemory byte array of shortlist (aligned to 64)
*/
explicit BatchTranslator(DeviceId const device, std::vector<Ptr<Vocab const>> &vocabs,
explicit BatchTranslator(DeviceId const device, Vocabs &vocabs,
Ptr<Options> options, const AlignedMemory* modelMemory, const AlignedMemory* shortlistMemory);
// convenience function for logging. TODO(jerin)
@ -45,7 +46,7 @@ public:
private:
Ptr<Options> options_;
DeviceId device_;
std::vector<Ptr<Vocab const>> *vocabs_;
const Vocabs& vocabs_;
Ptr<ExpressionGraph> graph_;
std::vector<Ptr<Scorer>> scorers_;
Ptr<data::ShortlistGenerator const> slgen_;

View File

@ -28,27 +28,6 @@ struct MemoryBundle {
/// @todo Not implemented yet
AlignedMemory ssplitPrefixFile;
MemoryBundle() = default;
MemoryBundle(MemoryBundle &&from){
model = std::move(from.model);
shortlist = std::move(from.shortlist);
vocabs = std::move(vocabs);
ssplitPrefixFile = std::move(from.ssplitPrefixFile);
}
MemoryBundle &operator=(MemoryBundle &&from) {
model = std::move(from.model);
shortlist = std::move(from.shortlist);
vocabs = std::move(vocabs);
ssplitPrefixFile = std::move(from.ssplitPrefixFile);
return *this;
}
// Delete copy constructors
MemoryBundle(const MemoryBundle&) = delete;
MemoryBundle& operator=(const MemoryBundle&) = delete;
};
} // namespace bergamot

View File

@ -65,11 +65,10 @@ void ResponseBuilder::buildTranslatedText(Histories &histories,
Result result = onebest[0]; // Expecting only one result;
Words words = std::get<0>(result);
auto targetVocab = vocabs_->back();
std::string decoded;
std::vector<string_view> targetSentenceMappings;
targetVocab->decodeWithByteRanges(words, decoded, targetSentenceMappings);
vocabs_.target()->decodeWithByteRanges(words, decoded, targetSentenceMappings);
switch (responseOptions_.concatStrategy) {
case ConcatStrategy::FAITHFUL: {

View File

@ -4,6 +4,7 @@
#include "data/types.h"
#include "response.h"
#include "response_options.h"
#include "vocabs.h"
// For now we will work with this, to avoid complaints another structure is hard
// to operate with.
@ -24,10 +25,10 @@ public:
/// @param [in] vocabs: marian vocab object (used in decoding)
/// @param [in] promise: promise to set with the constructed Response.
ResponseBuilder(ResponseOptions responseOptions, AnnotatedText &&source,
std::vector<Ptr<Vocab const>> &vocabs,
Vocabs &vocabs,
std::promise<Response> &&promise)
: responseOptions_(responseOptions), source_(std::move(source)),
vocabs_(&vocabs), promise_(std::move(promise)) {}
vocabs_(vocabs), promise_(std::move(promise)) {}
/// Constructs and sets the promise of a Response object from obtained
/// histories after translating.
@ -81,7 +82,7 @@ private:
// Data members are context/curried args for the functor.
ResponseOptions responseOptions_;
std::vector<Ptr<Vocab const>> *vocabs_; // vocabs are required for decoding
const Vocabs& vocabs_; // vocabs are required for decoding
// and any source validation checks.
std::promise<Response> promise_; // To be set when callback triggered and
// after Response constructed.

View File

@ -5,45 +5,12 @@
#include <string>
#include <utility>
inline std::vector<marian::Ptr<const marian::Vocab>>
loadVocabularies(marian::Ptr<marian::Options> options,
std::vector<std::shared_ptr<marian::bergamot::AlignedMemory>>&& vocabMemories) {
// @TODO: parallelize vocab loading for faster startup
std::vector<marian::Ptr<marian::Vocab const>> vocabs;
if(!vocabMemories.empty()){
// load vocabs from buffer
ABORT_IF(vocabMemories.size() < 2, "Insufficient number of vocabularies.");
vocabs.resize(vocabMemories.size());
for (size_t i = 0; i < vocabs.size(); i++) {
marian::Ptr<marian::Vocab> vocab = marian::New<marian::Vocab>(options, i);
vocab->loadFromSerialized(absl::string_view(vocabMemories[i]->begin(), vocabMemories[i]->size()));
vocabs[i] = vocab;
}
} else {
// load vocabs from file
auto vfiles = options->get<std::vector<std::string>>("vocabs");
// with the current setup, we need at least two vocabs: src and trg
ABORT_IF(vfiles.size() < 2, "Insufficient number of vocabularies.");
vocabs.resize(vfiles.size());
std::unordered_map<std::string, marian::Ptr<marian::Vocab>> vmap;
for (size_t i = 0; i < vocabs.size(); ++i) {
auto m = vmap.emplace(std::make_pair(vfiles[i], marian::Ptr<marian::Vocab>()));
if (m.second) { // new: load the vocab
m.first->second = marian::New<marian::Vocab>(options, i);
m.first->second->load(vfiles[i]);
}
vocabs[i] = m.first->second;
}
}
return vocabs;
}
namespace marian {
namespace bergamot {
Service::Service(Ptr<Options> options, MemoryBundle memoryBundle)
: requestId_(0), options_(options),
vocabs_(std::move(loadVocabularies(options, std::move(memoryBundle.vocabs)))),
vocabs_(options, std::move(memoryBundle.vocabs)),
text_processor_(vocabs_, options), batcher_(options),
numWorkers_(options->get<int>("cpu-threads")),
modelMemory_(std::move(memoryBundle.model)),

View File

@ -9,6 +9,7 @@
#include "response_builder.h"
#include "text_processor.h"
#include "translator/parser.h"
#include "vocabs.h"
#ifndef WASM_COMPATIBLE_SOURCE
#include "pcqueue.h"
@ -172,7 +173,7 @@ private:
size_t requestId_;
/// Store vocabs representing source and target.
std::vector<Ptr<Vocab const>> vocabs_; // ORDER DEPENDENCY (text_processor_)
Vocabs vocabs_; // ORDER DEPENDENCY (text_processor_)
/// TextProcesser takes a blob of text and converts into format consumable by
/// the batch-translator and annotates sentences and words.

View File

@ -4,7 +4,6 @@
#include "annotation.h"
#include "common/options.h"
#include "data/vocab.h"
#include <vector>
namespace marian {
@ -12,13 +11,14 @@ namespace bergamot {
Segment TextProcessor::tokenize(const string_view &segment,
std::vector<string_view> &wordRanges) {
return vocabs_->front()->encodeWithByteRanges(
// vocabs_->sources().front() is invoked as we currently only support one source vocab
return vocabs_.sources().front()->encodeWithByteRanges(
segment, wordRanges, /*addEOS=*/false, /*inference=*/true);
}
TextProcessor::TextProcessor(std::vector<Ptr<Vocab const>> &vocabs,
TextProcessor::TextProcessor(Vocabs &vocabs,
Ptr<Options> options)
: vocabs_(&vocabs), sentence_splitter_(options) {
: vocabs_(vocabs), sentence_splitter_(options) {
max_length_break_ = options->get<int>("max-length-break");
max_length_break_ = max_length_break_ - 1;

View File

@ -7,6 +7,7 @@
#include "annotation.h"
#include "sentence_splitter.h"
#include "vocabs.h"
#include <vector>
@ -21,7 +22,7 @@ class TextProcessor {
// sentences (vector of words). In addition, the ByteRanges of the
// source-tokens in unnormalized text are provided as string_views.
public:
explicit TextProcessor(std::vector<Ptr<Vocab const>> &vocabs, Ptr<Options>);
explicit TextProcessor(Vocabs &vocabs, Ptr<Options>);
void process(AnnotatedText &source, Segments &segments);
@ -36,9 +37,10 @@ private:
Segments &segments, AnnotatedText &source);
// shorthand, used only in truncate()
const Word sourceEosId() const { return vocabs_->front()->getEosId(); }
// vocabs_->sources().front() is invoked as we currently only support one source vocab
const Word sourceEosId() const { return vocabs_.sources().front()->getEosId(); }
std::vector<Ptr<Vocab const>> *vocabs_;
const Vocabs& vocabs_;
SentenceSplitter sentence_splitter_;
size_t max_length_break_;
};

81
src/translator/vocabs.h Normal file
View File

@ -0,0 +1,81 @@
#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace marian {
namespace bergamot {

/// Wrapper of Marian Vocab objects needed for translator.
/// Holds multiple source vocabularies and one target vocabulary.
class Vocabs {
public:
  /// Construct vocabs object from either byte-arrays or files.
  /// @param options marian options; the "vocabs" entry (file paths) is read
  ///        only when vocabMemories is empty.
  /// @param vocabMemories serialized vocabulary buffers; when non-empty they
  ///        take precedence over the "vocabs" file paths.
  Vocabs(Ptr<Options> options, std::vector<std::shared_ptr<AlignedMemory>>&& vocabMemories): options_(options){
    if (!vocabMemories.empty()){
      // load vocabs from buffer
      load(std::move(vocabMemories));
    }
    else{
      // load vocabs from file
      auto vocabPaths = options->get<std::vector<std::string>>("vocabs");
      load(vocabPaths);
    }
  }

  /// Get all source vocabularies (as a vector).
  const std::vector<Ptr<Vocab const>>& sources() const {
    return srcVocabs_;
  }

  /// Get the target vocabulary.
  const Ptr<Vocab const>& target() const {
    return trgVocab_;
  }

private:
  std::vector<Ptr<Vocab const>> srcVocabs_; // source vocabularies
  Ptr<Vocab const> trgVocab_;               // target vocabulary
  Ptr<Options> options_;

  // Split the last loaded vocabulary off as the target vocabulary; the
  // remaining entries are the source vocabularies. Moving (rather than
  // copying) the shared pointer avoids an atomic refcount round-trip.
  void popTargetVocab() {
    trgVocab_ = std::move(srcVocabs_.back());
    srcVocabs_.pop_back();
  }

  // load from buffer
  void load(std::vector<std::shared_ptr<AlignedMemory>>&& vocabMemories) {
    // At least two vocabs: src and trg
    ABORT_IF(vocabMemories.size() < 2, "Insufficient number of vocabularies.");
    srcVocabs_.resize(vocabMemories.size());
    // hashMap is introduced to avoid double loading the same vocab:
    // loading vocabs (either from buffers or files) is the biggest bottleneck
    // of startup speed. uintptr_t holds unique keys (addresses) for the
    // shared_ptr<AlignedMemory> buffers, so a shared src/trg vocab passed as
    // the same buffer is deserialized only once.
    std::unordered_map<uintptr_t, Ptr<Vocab>> vmap;
    for (size_t i = 0; i < srcVocabs_.size(); i++) {
      auto m = vmap.emplace(std::make_pair(reinterpret_cast<uintptr_t>(vocabMemories[i].get()), Ptr<Vocab>()));
      if (m.second) { // new: load the vocab
        m.first->second = New<Vocab>(options_, i);
        m.first->second->loadFromSerialized(absl::string_view(vocabMemories[i]->begin(), vocabMemories[i]->size()));
      }
      srcVocabs_[i] = m.first->second;
    }
    popTargetVocab();
  }

  // load from file
  void load(const std::vector<std::string>& vocabPaths){
    // with the current setup, we need at least two vocabs: src and trg
    ABORT_IF(vocabPaths.size() < 2, "Insufficient number of vocabularies.");
    srcVocabs_.resize(vocabPaths.size());
    // Deduplicate by path so a shared src/trg vocabulary file is loaded once.
    std::unordered_map<std::string, Ptr<Vocab>> vmap;
    for (size_t i = 0; i < srcVocabs_.size(); ++i) {
      auto m = vmap.emplace(std::make_pair(vocabPaths[i], Ptr<Vocab>()));
      if (m.second) { // new: load the vocab
        m.first->second = New<Vocab>(options_, i);
        m.first->second->load(vocabPaths[i]);
      }
      srcVocabs_[i] = m.first->second;
    }
    popTargetVocab();
  }
};

} // namespace bergamot
} // namespace marian