mirror of
https://github.com/browsermt/bergamot-translator.git
synced 2024-08-15 16:40:26 +03:00
Refactor vocabs in Service (#143)
Co-authored-by: Nikolay Bogoychev <nheart@gmail.com>
This commit is contained in:
parent
77424a3df1
commit
5bd1fc6b83
@ -10,11 +10,11 @@ namespace marian {
|
||||
namespace bergamot {
|
||||
|
||||
BatchTranslator::BatchTranslator(DeviceId const device,
|
||||
std::vector<Ptr<Vocab const>> &vocabs,
|
||||
Vocabs &vocabs,
|
||||
Ptr<Options> options,
|
||||
const AlignedMemory* modelMemory,
|
||||
const AlignedMemory* shortlistMemory)
|
||||
: device_(device), options_(options), vocabs_(&vocabs),
|
||||
: device_(device), options_(options), vocabs_(vocabs),
|
||||
modelMemory_(modelMemory), shortlistMemory_(shortlistMemory) {}
|
||||
|
||||
void BatchTranslator::initialize() {
|
||||
@ -22,17 +22,17 @@ void BatchTranslator::initialize() {
|
||||
bool check = options_->get<bool>("check-bytearray",false); // Flag holds whether validate the bytearray (model and shortlist)
|
||||
if (options_->hasAndNotEmpty("shortlist")) {
|
||||
int srcIdx = 0, trgIdx = 1;
|
||||
bool shared_vcb = vocabs_->front() == vocabs_->back();
|
||||
bool shared_vcb = vocabs_.sources().front() == vocabs_.target(); // vocabs_->sources().front() is invoked as we currently only support one source vocab
|
||||
if (shortlistMemory_->size() > 0 && shortlistMemory_->begin() != nullptr) {
|
||||
slgen_ = New<data::BinaryShortlistGenerator>(shortlistMemory_->begin(), shortlistMemory_->size(),
|
||||
vocabs_->front(), vocabs_->back(),
|
||||
srcIdx, trgIdx, shared_vcb, check);
|
||||
vocabs_.sources().front(), vocabs_.target(),
|
||||
srcIdx, trgIdx, shared_vcb, check);
|
||||
}
|
||||
else {
|
||||
// Changed to BinaryShortlistGenerator to enable loading binary shortlist file
|
||||
// This class also supports text shortlist file
|
||||
slgen_ = New<data::BinaryShortlistGenerator>(options_, vocabs_->front(),
|
||||
vocabs_->back(), srcIdx,
|
||||
slgen_ = New<data::BinaryShortlistGenerator>(options_, vocabs_.sources().front(),
|
||||
vocabs_.target(), srcIdx,
|
||||
trgIdx, shared_vcb);
|
||||
}
|
||||
}
|
||||
@ -97,7 +97,7 @@ void BatchTranslator::translate(Batch &batch) {
|
||||
std::vector<Ptr<SubBatch>> subBatches;
|
||||
for (size_t j = 0; j < maxDims.size(); ++j) {
|
||||
subBatches.emplace_back(
|
||||
New<SubBatch>(batchSize, maxDims[j], vocabs_->at(j)));
|
||||
New<SubBatch>(batchSize, maxDims[j], vocabs_.sources().at(j)));
|
||||
}
|
||||
|
||||
std::vector<size_t> words(maxDims.size(), 0);
|
||||
@ -116,9 +116,8 @@ void BatchTranslator::translate(Batch &batch) {
|
||||
|
||||
auto corpus_batch = Ptr<CorpusBatch>(new CorpusBatch(subBatches));
|
||||
corpus_batch->setSentenceIds(sentenceIds);
|
||||
|
||||
auto trgVocab = vocabs_->back();
|
||||
auto search = New<BeamSearch>(options_, scorers_, trgVocab);
|
||||
|
||||
auto search = New<BeamSearch>(options_, scorers_, vocabs_.target());
|
||||
|
||||
auto histories = std::move(search->search(graph_, corpus_batch));
|
||||
batch.completeBatch(histories);
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "request.h"
|
||||
#include "translator/history.h"
|
||||
#include "translator/scorers.h"
|
||||
#include "vocabs.h"
|
||||
|
||||
#ifndef WASM_COMPATIBLE_SOURCE
|
||||
#include "pcqueue.h"
|
||||
@ -34,7 +35,7 @@ public:
|
||||
* @param modelMemory byte array (aligned to 256!!!) that contains the bytes of a model.bin. Provide a nullptr if not used.
|
||||
* @param shortlistMemory byte array of shortlist (aligned to 64)
|
||||
*/
|
||||
explicit BatchTranslator(DeviceId const device, std::vector<Ptr<Vocab const>> &vocabs,
|
||||
explicit BatchTranslator(DeviceId const device, Vocabs &vocabs,
|
||||
Ptr<Options> options, const AlignedMemory* modelMemory, const AlignedMemory* shortlistMemory);
|
||||
|
||||
// convenience function for logging. TODO(jerin)
|
||||
@ -45,7 +46,7 @@ public:
|
||||
private:
|
||||
Ptr<Options> options_;
|
||||
DeviceId device_;
|
||||
std::vector<Ptr<Vocab const>> *vocabs_;
|
||||
const Vocabs& vocabs_;
|
||||
Ptr<ExpressionGraph> graph_;
|
||||
std::vector<Ptr<Scorer>> scorers_;
|
||||
Ptr<data::ShortlistGenerator const> slgen_;
|
||||
|
@ -28,27 +28,6 @@ struct MemoryBundle {
|
||||
|
||||
/// @todo Not implemented yet
|
||||
AlignedMemory ssplitPrefixFile;
|
||||
|
||||
MemoryBundle() = default;
|
||||
|
||||
MemoryBundle(MemoryBundle &&from){
|
||||
model = std::move(from.model);
|
||||
shortlist = std::move(from.shortlist);
|
||||
vocabs = std::move(vocabs);
|
||||
ssplitPrefixFile = std::move(from.ssplitPrefixFile);
|
||||
}
|
||||
|
||||
MemoryBundle &operator=(MemoryBundle &&from) {
|
||||
model = std::move(from.model);
|
||||
shortlist = std::move(from.shortlist);
|
||||
vocabs = std::move(vocabs);
|
||||
ssplitPrefixFile = std::move(from.ssplitPrefixFile);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Delete copy constructors
|
||||
MemoryBundle(const MemoryBundle&) = delete;
|
||||
MemoryBundle& operator=(const MemoryBundle&) = delete;
|
||||
};
|
||||
|
||||
} // namespace bergamot
|
||||
|
@ -65,11 +65,10 @@ void ResponseBuilder::buildTranslatedText(Histories &histories,
|
||||
|
||||
Result result = onebest[0]; // Expecting only one result;
|
||||
Words words = std::get<0>(result);
|
||||
auto targetVocab = vocabs_->back();
|
||||
|
||||
std::string decoded;
|
||||
std::vector<string_view> targetSentenceMappings;
|
||||
targetVocab->decodeWithByteRanges(words, decoded, targetSentenceMappings);
|
||||
vocabs_.target()->decodeWithByteRanges(words, decoded, targetSentenceMappings);
|
||||
|
||||
switch (responseOptions_.concatStrategy) {
|
||||
case ConcatStrategy::FAITHFUL: {
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include "data/types.h"
|
||||
#include "response.h"
|
||||
#include "response_options.h"
|
||||
#include "vocabs.h"
|
||||
|
||||
// For now we will work with this, to avoid complaints another structure is hard
|
||||
// to operate with.
|
||||
@ -24,10 +25,10 @@ public:
|
||||
/// @param [in] vocabs: marian vocab object (used in decoding)
|
||||
/// @param [in] promise: promise to set with the constructed Response.
|
||||
ResponseBuilder(ResponseOptions responseOptions, AnnotatedText &&source,
|
||||
std::vector<Ptr<Vocab const>> &vocabs,
|
||||
Vocabs &vocabs,
|
||||
std::promise<Response> &&promise)
|
||||
: responseOptions_(responseOptions), source_(std::move(source)),
|
||||
vocabs_(&vocabs), promise_(std::move(promise)) {}
|
||||
vocabs_(vocabs), promise_(std::move(promise)) {}
|
||||
|
||||
/// Constructs and sets the promise of a Response object from obtained
|
||||
/// histories after translating.
|
||||
@ -81,7 +82,7 @@ private:
|
||||
// Data members are context/curried args for the functor.
|
||||
|
||||
ResponseOptions responseOptions_;
|
||||
std::vector<Ptr<Vocab const>> *vocabs_; // vocabs are required for decoding
|
||||
const Vocabs& vocabs_; // vocabs are required for decoding
|
||||
// and any source validation checks.
|
||||
std::promise<Response> promise_; // To be set when callback triggered and
|
||||
// after Response constructed.
|
||||
|
@ -5,45 +5,12 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
inline std::vector<marian::Ptr<const marian::Vocab>>
|
||||
loadVocabularies(marian::Ptr<marian::Options> options,
|
||||
std::vector<std::shared_ptr<marian::bergamot::AlignedMemory>>&& vocabMemories) {
|
||||
// @TODO: parallelize vocab loading for faster startup
|
||||
std::vector<marian::Ptr<marian::Vocab const>> vocabs;
|
||||
if(!vocabMemories.empty()){
|
||||
// load vocabs from buffer
|
||||
ABORT_IF(vocabMemories.size() < 2, "Insufficient number of vocabularies.");
|
||||
vocabs.resize(vocabMemories.size());
|
||||
for (size_t i = 0; i < vocabs.size(); i++) {
|
||||
marian::Ptr<marian::Vocab> vocab = marian::New<marian::Vocab>(options, i);
|
||||
vocab->loadFromSerialized(absl::string_view(vocabMemories[i]->begin(), vocabMemories[i]->size()));
|
||||
vocabs[i] = vocab;
|
||||
}
|
||||
} else {
|
||||
// load vocabs from file
|
||||
auto vfiles = options->get<std::vector<std::string>>("vocabs");
|
||||
// with the current setup, we need at least two vocabs: src and trg
|
||||
ABORT_IF(vfiles.size() < 2, "Insufficient number of vocabularies.");
|
||||
vocabs.resize(vfiles.size());
|
||||
std::unordered_map<std::string, marian::Ptr<marian::Vocab>> vmap;
|
||||
for (size_t i = 0; i < vocabs.size(); ++i) {
|
||||
auto m = vmap.emplace(std::make_pair(vfiles[i], marian::Ptr<marian::Vocab>()));
|
||||
if (m.second) { // new: load the vocab
|
||||
m.first->second = marian::New<marian::Vocab>(options, i);
|
||||
m.first->second->load(vfiles[i]);
|
||||
}
|
||||
vocabs[i] = m.first->second;
|
||||
}
|
||||
}
|
||||
return vocabs;
|
||||
}
|
||||
|
||||
namespace marian {
|
||||
namespace bergamot {
|
||||
|
||||
Service::Service(Ptr<Options> options, MemoryBundle memoryBundle)
|
||||
: requestId_(0), options_(options),
|
||||
vocabs_(std::move(loadVocabularies(options, std::move(memoryBundle.vocabs)))),
|
||||
vocabs_(options, std::move(memoryBundle.vocabs)),
|
||||
text_processor_(vocabs_, options), batcher_(options),
|
||||
numWorkers_(options->get<int>("cpu-threads")),
|
||||
modelMemory_(std::move(memoryBundle.model)),
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "response_builder.h"
|
||||
#include "text_processor.h"
|
||||
#include "translator/parser.h"
|
||||
#include "vocabs.h"
|
||||
|
||||
#ifndef WASM_COMPATIBLE_SOURCE
|
||||
#include "pcqueue.h"
|
||||
@ -172,7 +173,7 @@ private:
|
||||
|
||||
size_t requestId_;
|
||||
/// Store vocabs representing source and target.
|
||||
std::vector<Ptr<Vocab const>> vocabs_; // ORDER DEPENDENCY (text_processor_)
|
||||
Vocabs vocabs_; // ORDER DEPENDENCY (text_processor_)
|
||||
|
||||
/// TextProcesser takes a blob of text and converts into format consumable by
|
||||
/// the batch-translator and annotates sentences and words.
|
||||
|
@ -4,7 +4,6 @@
|
||||
#include "annotation.h"
|
||||
|
||||
#include "common/options.h"
|
||||
#include "data/vocab.h"
|
||||
#include <vector>
|
||||
|
||||
namespace marian {
|
||||
@ -12,13 +11,14 @@ namespace bergamot {
|
||||
|
||||
Segment TextProcessor::tokenize(const string_view &segment,
|
||||
std::vector<string_view> &wordRanges) {
|
||||
return vocabs_->front()->encodeWithByteRanges(
|
||||
// vocabs_->sources().front() is invoked as we currently only support one source vocab
|
||||
return vocabs_.sources().front()->encodeWithByteRanges(
|
||||
segment, wordRanges, /*addEOS=*/false, /*inference=*/true);
|
||||
}
|
||||
|
||||
TextProcessor::TextProcessor(std::vector<Ptr<Vocab const>> &vocabs,
|
||||
TextProcessor::TextProcessor(Vocabs &vocabs,
|
||||
Ptr<Options> options)
|
||||
: vocabs_(&vocabs), sentence_splitter_(options) {
|
||||
: vocabs_(vocabs), sentence_splitter_(options) {
|
||||
|
||||
max_length_break_ = options->get<int>("max-length-break");
|
||||
max_length_break_ = max_length_break_ - 1;
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include "annotation.h"
|
||||
|
||||
#include "sentence_splitter.h"
|
||||
#include "vocabs.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
@ -21,7 +22,7 @@ class TextProcessor {
|
||||
// sentences (vector of words). In addition, the ByteRanges of the
|
||||
// source-tokens in unnormalized text are provided as string_views.
|
||||
public:
|
||||
explicit TextProcessor(std::vector<Ptr<Vocab const>> &vocabs, Ptr<Options>);
|
||||
explicit TextProcessor(Vocabs &vocabs, Ptr<Options>);
|
||||
|
||||
void process(AnnotatedText &source, Segments &segments);
|
||||
|
||||
@ -36,9 +37,10 @@ private:
|
||||
Segments &segments, AnnotatedText &source);
|
||||
|
||||
// shorthand, used only in truncate()
|
||||
const Word sourceEosId() const { return vocabs_->front()->getEosId(); }
|
||||
// vocabs_->sources().front() is invoked as we currently only support one source vocab
|
||||
const Word sourceEosId() const { return vocabs_.sources().front()->getEosId(); }
|
||||
|
||||
std::vector<Ptr<Vocab const>> *vocabs_;
|
||||
const Vocabs& vocabs_;
|
||||
SentenceSplitter sentence_splitter_;
|
||||
size_t max_length_break_;
|
||||
};
|
||||
|
81
src/translator/vocabs.h
Normal file
81
src/translator/vocabs.h
Normal file
@ -0,0 +1,81 @@
|
||||
#pragma once
|
||||
|
||||
namespace marian {
|
||||
namespace bergamot {
|
||||
|
||||
/// Wrapper of Marian Vocab objects needed for translator.
|
||||
/// Holds multiple source vocabularies and one target vocabulary
|
||||
class Vocabs {
|
||||
public:
|
||||
/// Construct vocabs object from either byte-arrays or files
|
||||
Vocabs(Ptr<Options> options, std::vector<std::shared_ptr<AlignedMemory>>&& vocabMemories): options_(options){
|
||||
if (!vocabMemories.empty()){
|
||||
// load vocabs from buffer
|
||||
load(std::move(vocabMemories));
|
||||
}
|
||||
else{
|
||||
// load vocabs from file
|
||||
auto vocabPaths = options->get<std::vector<std::string>>("vocabs");
|
||||
load(vocabPaths);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all source vocabularies (as a vector)
|
||||
const std::vector<Ptr<Vocab const>>& sources() const {
|
||||
return srcVocabs_;
|
||||
}
|
||||
|
||||
/// Get the target vocabulary
|
||||
const Ptr<Vocab const>& target() const {
|
||||
return trgVocab_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Ptr<Vocab const>> srcVocabs_; // source vocabularies
|
||||
Ptr<Vocab const> trgVocab_; // target vocabulary
|
||||
Ptr<Options> options_;
|
||||
|
||||
// load from buffer
|
||||
void load(std::vector<std::shared_ptr<AlignedMemory>>&& vocabMemories) {
|
||||
// At least two vocabs: src and trg
|
||||
ABORT_IF(vocabMemories.size() < 2, "Insufficient number of vocabularies.");
|
||||
srcVocabs_.resize(vocabMemories.size());
|
||||
// hashMap is introduced to avoid double loading the same vocab
|
||||
// loading vocabs (either from buffers or files) is the biggest bottleneck of the speed
|
||||
// uintptr_t holds unique keys (address) for share_ptr<AlignedMemory>
|
||||
std::unordered_map<uintptr_t, Ptr<Vocab>> vmap;
|
||||
for (size_t i = 0; i < srcVocabs_.size(); i++) {
|
||||
auto m = vmap.emplace(std::make_pair(reinterpret_cast<uintptr_t>(vocabMemories[i].get()), Ptr<Vocab>()));
|
||||
if (m.second) { // new: load the vocab
|
||||
m.first->second = New<Vocab>(options_, i);
|
||||
m.first->second->loadFromSerialized(absl::string_view(vocabMemories[i]->begin(), vocabMemories[i]->size()));
|
||||
}
|
||||
srcVocabs_[i] = m.first->second;
|
||||
}
|
||||
// Initialize target vocab
|
||||
trgVocab_ = srcVocabs_.back();
|
||||
srcVocabs_.pop_back();
|
||||
}
|
||||
|
||||
// load from file
|
||||
void load(const std::vector<std::string>& vocabPaths){
|
||||
// with the current setup, we need at least two vocabs: src and trg
|
||||
ABORT_IF(vocabPaths.size() < 2, "Insufficient number of vocabularies.");
|
||||
srcVocabs_.resize(vocabPaths.size());
|
||||
std::unordered_map<std::string, Ptr<Vocab>> vmap;
|
||||
for (size_t i = 0; i < srcVocabs_.size(); ++i) {
|
||||
auto m = vmap.emplace(std::make_pair(vocabPaths[i], Ptr<Vocab>()));
|
||||
if (m.second) { // new: load the vocab
|
||||
m.first->second = New<Vocab>(options_, i);
|
||||
m.first->second->load(vocabPaths[i]);
|
||||
}
|
||||
srcVocabs_[i] = m.first->second;
|
||||
}
|
||||
// Initialize target vocab
|
||||
trgVocab_ = srcVocabs_.back();
|
||||
srcVocabs_.pop_back();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace bergamot
|
||||
} // namespace marian
|
Loading…
Reference in New Issue
Block a user