Format update with clang-format

Jerin Philip 2021-05-18 19:53:23 +00:00
parent 95ba38c90d
commit 5340b19eae
29 changed files with 364 additions and 440 deletions
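The .clang-format configuration itself is not part of this diff, so the exact settings are an assumption; the hunks below (lines wrapped near 120 characters, includes sorted and regrouped, short ifs and one-line accessor bodies kept on a single line, constructor initializers broken one per line) are consistent with a Google-based style along these lines:

# Hypothetical .clang-format inferred from the formatting in this commit;
# the repository's real file may differ.
BasedOnStyle: Google    # regroups/sorts includes, allows short ifs and short functions on one line
ColumnLimit: 120        # the reformatted lines wrap at roughly 120 columns rather than Google's default 80

With such a file at the repository root, reproducing this commit should be a matter of running clang-format -i over the changed sources.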

View File

@ -1,8 +1,9 @@
#include "catch.hpp"
#include "translator/annotation.h"
#include <random>
#include <vector>
#include "catch.hpp"
#include "translator/annotation.h"
using namespace marian::bergamot;
TEST_CASE("Test Annotation API with random sentences") {
@ -52,8 +53,7 @@ TEST_CASE("Test Annotation API with random sentences") {
}
std::string text;
for (size_t idx = 0; idx < sentences; idx++) {
if (idx != 0)
text += "\n";
if (idx != 0) text += "\n";
// Words can be zero, we need to support empty word sentences as well.
size_t numWords = randomIntGen_() % maxWords;
@ -105,8 +105,7 @@ TEST_CASE("Test Annotation API with random sentences") {
// the math underneath.
if (debug) {
std::cout << "Inserting words onto container and save ground-truth-table:"
<< std::endl;
std::cout << "Inserting words onto container and save ground-truth-table:" << std::endl;
}
std::vector<std::vector<marian::string_view>> wordStringViews;
@ -115,8 +114,7 @@ TEST_CASE("Test Annotation API with random sentences") {
std::vector<marian::string_view> wordByteRanges;
bool first{true};
for (auto &word : sentence) {
marian::string_view wordView(&testAnnotation.text[word.begin],
word.size());
marian::string_view wordView(&testAnnotation.text[word.begin], word.size());
wordByteRanges.push_back(wordView);
if (debug) {
if (first) {
@ -127,7 +125,8 @@ TEST_CASE("Test Annotation API with random sentences") {
std::cout << std::string(wordView);
}
}
testAnnotation.recordExistingSentence(wordByteRanges.begin(), wordByteRanges.end(), testAnnotation.text.data() + sentence_iter->begin);
testAnnotation.recordExistingSentence(wordByteRanges.begin(), wordByteRanges.end(),
testAnnotation.text.data() + sentence_iter->begin);
++sentence_iter;
wordStringViews.push_back(wordByteRanges);
if (debug) {
@ -136,9 +135,7 @@ TEST_CASE("Test Annotation API with random sentences") {
}
if (debug) {
std::cout
<< "Inserting sentences onto container and save ground-truth-table"
<< std::endl;
std::cout << "Inserting sentences onto container and save ground-truth-table" << std::endl;
}
std::vector<marian::string_view> sentenceStringViews;
for (auto &sentenceByteRange : groundTruthSentences) {
@ -203,7 +200,8 @@ TEST_CASE("Test Annotation API with random sentences") {
// Sentence if the random test above does not cover it for some reason.
int emptySentenceIdx = sentences;
std::vector<marian::string_view> emptySentence;
testAnnotation.recordExistingSentence(emptySentence.begin(), emptySentence.end(), testAnnotation.text.data() + testAnnotation.text.size());
testAnnotation.recordExistingSentence(emptySentence.begin(), emptySentence.end(),
testAnnotation.text.data() + testAnnotation.text.size());
// There are no words.
CHECK(testAnnotation.numWords(emptySentenceIdx) == 0);

View File

@ -1,4 +1,5 @@
#include "annotation.h"
#include <cassert>
namespace marian {
@ -9,7 +10,8 @@ AnnotatedText::AnnotatedText(std::string &&t) : text(std::move(t)) {
annotation.token_begin_.back() = text.size();
}
void AnnotatedText::appendSentence(string_view prefix, std::vector<string_view>::iterator begin, std::vector<string_view>::iterator end) {
void AnnotatedText::appendSentence(string_view prefix, std::vector<string_view>::iterator begin,
std::vector<string_view>::iterator end) {
assert(annotation.token_begin_.back() == text.size());
// We'll be adding tokens from the sentence and another gap.
annotation.token_begin_.reserve(annotation.token_begin_.size() + (end - begin) + 1);
@ -39,7 +41,8 @@ void AnnotatedText::appendEndingWhitespace(string_view whitespace) {
annotation.token_begin_.back() = text.size();
}
void AnnotatedText::recordExistingSentence(std::vector<string_view>::iterator begin, std::vector<string_view>::iterator end, const char *sentence_begin) {
void AnnotatedText::recordExistingSentence(std::vector<string_view>::iterator begin,
std::vector<string_view>::iterator end, const char *sentence_begin) {
assert(sentence_begin >= text.data());
assert(sentence_begin <= text.data() + text.size());
assert(begin == end || sentence_begin == begin->data());

View File

@ -1,11 +1,12 @@
#ifndef BERGAMOT_SENTENCE_RANGES_H_
#define BERGAMOT_SENTENCE_RANGES_H_
#include "data/types.h"
#include <cassert>
#include <utility>
#include <vector>
#include "data/types.h"
namespace marian {
namespace bergamot {
@ -143,9 +144,7 @@ public:
/// string_views. Since this tracks only prefix, remember
/// appendEndingWhitespace.
/// The string_views must not already be in text.
void appendSentence(
string_view prefix,
std::vector<string_view>::iterator tokens_begin,
void appendSentence(string_view prefix, std::vector<string_view>::iterator tokens_begin,
std::vector<string_view>::iterator tokens_end);
/// Append the whitespace at the end of input. string_view must not be in
@ -158,18 +157,14 @@ public:
/// Normally the beginning of the sentence can be inferred from
/// tokens_begin->data() but the tokens could be empty, so sentence_begin is
/// required to know where the sentence is.
void recordExistingSentence(
std::vector<string_view>::iterator tokens_begin,
std::vector<string_view>::iterator tokens_end,
const char *sentence_begin);
void recordExistingSentence(std::vector<string_view>::iterator tokens_begin,
std::vector<string_view>::iterator tokens_end, const char *sentence_begin);
/// Returns the number of sentences in the annotation structure.
const size_t numSentences() const { return annotation.numSentences(); }
/// Returns number of words in the sentence identified by sentenceIdx.
const size_t numWords(size_t sentenceIdx) const {
return annotation.numWords(sentenceIdx);
}
const size_t numWords(size_t sentenceIdx) const { return annotation.numWords(sentenceIdx); }
/// Returns a string_view representing wordIdx in sentenceIdx
string_view word(size_t sentenceIdx, size_t wordIdx) const {
@ -177,9 +172,7 @@ public:
}
/// Returns a string_view representing sentence corresponding to sentenceIdx.
string_view sentence(size_t sentenceIdx) const {
return asStringView(annotation.sentence(sentenceIdx));
}
string_view sentence(size_t sentenceIdx) const { return asStringView(annotation.sentence(sentenceIdx)); }
/// Returns the string_view of the gap between two sentences in the container.
///
@ -191,19 +184,13 @@ public:
/// * For `i = N`, the gap between the last (N-1th) sentence and end of
/// text.
/// @param sentenceIdx: Can be between `[0, numSentences()]`.
string_view gap(size_t sentenceIdx) const {
return asStringView(annotation.gap(sentenceIdx));
}
string_view gap(size_t sentenceIdx) const { return asStringView(annotation.gap(sentenceIdx)); }
/// Returns a ByteRange representing wordIdx in sentenceIdx
ByteRange wordAsByteRange(size_t sentenceIdx, size_t wordIdx) const {
return annotation.word(sentenceIdx, wordIdx);
}
ByteRange wordAsByteRange(size_t sentenceIdx, size_t wordIdx) const { return annotation.word(sentenceIdx, wordIdx); }
/// Returns a ByteRange representing sentence corresponding to sentenceIdx.
ByteRange sentenceAsByteRange(size_t sentenceIdx) const {
return annotation.sentence(sentenceIdx);
}
ByteRange sentenceAsByteRange(size_t sentenceIdx) const { return annotation.sentence(sentenceIdx); }
private:
string_view asStringView(const ByteRange &byteRange) const {

View File

@ -1,4 +1,5 @@
#include "batch.h"
#include "request.h"
namespace marian {
@ -11,13 +12,10 @@ void Batch::log() {
maxLength = std::max(maxLength, static_cast<size_t>(sentence.numTokens()));
}
LOG(info, "Batch(tokens={}, max-length={}, sentences_={})", numTokens,
maxLength, sentences_.size());
LOG(info, "Batch(tokens={}, max-length={}, sentences_={})", numTokens, maxLength, sentences_.size());
}
void Batch::add(const RequestSentence &sentence) {
sentences_.push_back(sentence);
}
void Batch::add(const RequestSentence &sentence) { sentences_.push_back(sentence); }
void Batch::completeBatch(const Histories &histories) {
for (size_t i = 0; i < sentences_.size(); i++) {

View File

@ -1,38 +1,40 @@
#include "batch_translator.h"
#include "batch.h"
#include "byte_array_util.h"
#include "common/logging.h"
#include "data/corpus.h"
#include "data/text_input.h"
#include "translator/beam_search.h"
#include "byte_array_util.h"
namespace marian {
namespace bergamot {
BatchTranslator::BatchTranslator(DeviceId const device,
Vocabs &vocabs,
Ptr<Options> options,
const AlignedMemory* modelMemory,
const AlignedMemory* shortlistMemory)
: device_(device), options_(options), vocabs_(vocabs),
modelMemory_(modelMemory), shortlistMemory_(shortlistMemory) {}
BatchTranslator::BatchTranslator(DeviceId const device, Vocabs &vocabs, Ptr<Options> options,
const AlignedMemory *modelMemory, const AlignedMemory *shortlistMemory)
: device_(device),
options_(options),
vocabs_(vocabs),
modelMemory_(modelMemory),
shortlistMemory_(shortlistMemory) {}
void BatchTranslator::initialize() {
// Initializes the graph.
bool check = options_->get<bool>("check-bytearray",false); // Flag holds whether validate the bytearray (model and shortlist)
bool check =
options_->get<bool>("check-bytearray", false); // Flag holds whether validate the bytearray (model and shortlist)
if (options_->hasAndNotEmpty("shortlist")) {
int srcIdx = 0, trgIdx = 1;
bool shared_vcb = vocabs_.sources().front() == vocabs_.target(); // vocabs_->sources().front() is invoked as we currently only support one source vocab
bool shared_vcb =
vocabs_.sources().front() ==
vocabs_.target(); // vocabs_->sources().front() is invoked as we currently only support one source vocab
if (shortlistMemory_->size() > 0 && shortlistMemory_->begin() != nullptr) {
slgen_ = New<data::BinaryShortlistGenerator>(shortlistMemory_->begin(), shortlistMemory_->size(),
vocabs_.sources().front(), vocabs_.target(),
srcIdx, trgIdx, shared_vcb, check);
}
else {
vocabs_.sources().front(), vocabs_.target(), srcIdx, trgIdx,
shared_vcb, check);
} else {
// Changed to BinaryShortlistGenerator to enable loading binary shortlist file
// This class also supports text shortlist file
slgen_ = New<data::BinaryShortlistGenerator>(options_, vocabs_.sources().front(),
vocabs_.target(), srcIdx,
slgen_ = New<data::BinaryShortlistGenerator>(options_, vocabs_.sources().front(), vocabs_.target(), srcIdx,
trgIdx, shared_vcb);
}
}
@ -43,14 +45,19 @@ void BatchTranslator::initialize() {
graph_->setDevice(device_);
graph_->getBackend()->configureDevice(options_);
graph_->reserveWorkspaceMB(options_->get<size_t>("workspace"));
if (modelMemory_->size() > 0 && modelMemory_->begin() != nullptr) { // If we have provided a byte array that contains the model memory, we can initialise the model from there, as opposed to from reading in the config file
if (modelMemory_->size() > 0 &&
modelMemory_->begin() !=
nullptr) { // If we have provided a byte array that contains the model memory, we can initialise the model
// from there, as opposed to from reading in the config file
ABORT_IF((uintptr_t)modelMemory_->begin() % 256 != 0,
"The provided memory is not aligned to 256 bytes and will crash when vector instructions are used on it.");
if (check) {
ABORT_IF(!validateBinaryModel(*modelMemory_, modelMemory_->size()),
"The binary file is invalid. Incomplete or corrupted download?");
}
const std::vector<const void *> container = {modelMemory_->begin()}; // Marian supports multiple models initialised in this manner hence std::vector. However we will only ever use 1 during decoding.
const std::vector<const void *> container = {
modelMemory_->begin()}; // Marian supports multiple models initialised in this manner hence std::vector.
// However we will only ever use 1 during decoding.
scorers_ = createScorers(options_, container);
} else {
scorers_ = createScorers(options_);
@ -82,11 +89,9 @@ void BatchTranslator::translate(Batch &batch) {
std::vector<size_t> sentenceIds;
std::vector<int> maxDims;
for (auto &ex : batchVector) {
if (maxDims.size() < ex.size())
maxDims.resize(ex.size(), 0);
if (maxDims.size() < ex.size()) maxDims.resize(ex.size(), 0);
for (size_t i = 0; i < ex.size(); ++i) {
if (ex[i].size() > (size_t)maxDims[i])
maxDims[i] = (int)ex[i].size();
if (ex[i].size() > (size_t)maxDims[i]) maxDims[i] = (int)ex[i].size();
}
sentenceIds.push_back(ex.getId());
}
@ -96,8 +101,7 @@ void BatchTranslator::translate(Batch &batch) {
std::vector<Ptr<SubBatch>> subBatches;
for (size_t j = 0; j < maxDims.size(); ++j) {
subBatches.emplace_back(
New<SubBatch>(batchSize, maxDims[j], vocabs_.sources().at(j)));
subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_.sources().at(j)));
}
std::vector<size_t> words(maxDims.size(), 0);
@ -111,8 +115,7 @@ void BatchTranslator::translate(Batch &batch) {
}
}
for (size_t j = 0; j < maxDims.size(); ++j)
subBatches[j]->setWords(words[j]);
for (size_t j = 0; j < maxDims.size(); ++j) subBatches[j]->setWords(words[j]);
auto corpus_batch = Ptr<CorpusBatch>(new CorpusBatch(subBatches));
corpus_batch->setSentenceIds(sentenceIds);

View File

@ -32,11 +32,12 @@ public:
* @param device DeviceId that performs translation. Could be CPU or GPU
* @param vocabs Vector that contains ptrs to two vocabs
* @param options Marian options object
* @param modelMemory byte array (aligned to 256!!!) that contains the bytes of a model.bin. Provide a nullptr if not used.
* @param modelMemory byte array (aligned to 256!!!) that contains the bytes of a model.bin. Provide a nullptr if not
* used.
* @param shortlistMemory byte array of shortlist (aligned to 64)
*/
explicit BatchTranslator(DeviceId const device, Vocabs &vocabs,
Ptr<Options> options, const AlignedMemory* modelMemory, const AlignedMemory* shortlistMemory);
explicit BatchTranslator(DeviceId const device, Vocabs& vocabs, Ptr<Options> options,
const AlignedMemory* modelMemory, const AlignedMemory* shortlistMemory);
// convenience function for logging. TODO(jerin)
std::string _identifier() { return "worker" + std::to_string(device_.no); }

View File

@ -1,7 +1,9 @@
#include "batcher.h"
#include <cassert>
#include "batch.h"
#include "common/logging.h"
#include <cassert>
namespace marian {
namespace bergamot {

View File

@ -1,5 +1,7 @@
#include "byte_array_util.h"
#include <stdlib.h>
#include <iostream>
#include <memory>
@ -30,7 +32,8 @@ const T* get(const void*& current, uint64_t num = 1) {
bool validateBinaryModel(const AlignedMemory& model, uint64_t fileSize) {
const void* current = model.begin();
uint64_t memoryNeeded = sizeof(uint64_t)*2; // We keep track of how much memory we would need if we have a complete file
uint64_t memoryNeeded =
sizeof(uint64_t) * 2; // We keep track of how much memory we would need if we have a complete file
uint64_t numHeaders;
if (fileSize >= memoryNeeded) { // We have enough filesize to fetch the headers.
uint64_t binaryFileVersion = *get<uint64_t>(current);

View File

@ -1,5 +1,5 @@
#include "marian.h"
#include "definitions.h"
#include "marian.h"
namespace marian {
namespace bergamot {

View File

@ -1,10 +1,11 @@
#ifndef SRC_BERGAMOT_DEFINITIONS_H_
#define SRC_BERGAMOT_DEFINITIONS_H_
#include <vector>
#include "aligned.h"
#include "data/types.h"
#include "data/vocab_base.h"
#include "aligned.h"
#include <vector>
namespace marian {
namespace bergamot {

View File

@ -12,26 +12,21 @@ namespace bergamot {
inline marian::ConfigParser createConfigParser() {
marian::ConfigParser cp(marian::cli::mode::translation);
cp.addOption<std::string>(
"--ssplit-prefix-file", "Bergamot Options",
cp.addOption<std::string>("--ssplit-prefix-file", "Bergamot Options",
"File with nonbreaking prefixes for sentence splitting.");
cp.addOption<std::string>("--ssplit-mode", "Server Options",
"[paragraph, sentence, wrapped_text]", "paragraph");
cp.addOption<std::string>("--ssplit-mode", "Server Options", "[paragraph, sentence, wrapped_text]", "paragraph");
cp.addOption<int>(
"--max-length-break", "Bergamot Options",
cp.addOption<int>("--max-length-break", "Bergamot Options",
"Maximum input tokens to be processed in a single sentence.", 128);
cp.addOption<bool>(
"--check-bytearray", "Bergamot Options",
cp.addOption<bool>("--check-bytearray", "Bergamot Options",
"Flag holds whether to check the content of the bytearray (true by default)", true);
return cp;
}
inline std::shared_ptr<marian::Options>
parseOptions(const std::string &config, bool validate = true) {
inline std::shared_ptr<marian::Options> parseOptions(const std::string &config, bool validate = true) {
marian::Options options;
// @TODO(jerinphilip) There's something off here, @XapaJIaMnu suggests

View File

@ -1,23 +1,22 @@
#include "request.h"
#include "definitions.h"
#include "response.h"
#include "annotation.h"
#include "common/logging.h"
#include <string>
#include "annotation.h"
#include "common/logging.h"
#include "definitions.h"
#include "response.h"
namespace marian {
namespace bergamot {
// -----------------------------------------------------------------
Request::Request(size_t Id, Segments &&segments,
ResponseBuilder &&responseBuilder)
: Id_(Id), segments_(std::move(segments)),
Request::Request(size_t Id, Segments &&segments, ResponseBuilder &&responseBuilder)
: Id_(Id),
segments_(std::move(segments)),
responseBuilder_(std::move(responseBuilder))
{
counter_ = segments_.size();
histories_.resize(segments_.size(), nullptr);
@ -31,9 +30,7 @@ Request::Request(size_t Id, Segments &&segments,
size_t Request::numSegments() const { return segments_.size(); }
size_t Request::segmentTokens(size_t index) const {
return (segments_[index].size());
}
size_t Request::segmentTokens(size_t index) const { return (segments_[index].size()); }
Segment Request::getSegment(size_t index) const { return segments_[index]; }
@ -56,12 +53,9 @@ bool Request::operator<(const Request &b) const {
// ------------------------------------------------------------------
RequestSentence::RequestSentence(size_t index, Ptr<Request> request)
: index_(index), request_(request) {}
RequestSentence::RequestSentence(size_t index, Ptr<Request> request) : index_(index), request_(request) {}
size_t RequestSentence::numTokens() const {
return (request_->segmentTokens(index_));
}
size_t RequestSentence::numTokens() const { return (request_->segmentTokens(index_)); }
void RequestSentence::completeSentence(Ptr<History> history) {
// Relays completeSentence into request's processHistory, using index
@ -69,9 +63,7 @@ void RequestSentence::completeSentence(Ptr<History> history) {
request_->processHistory(index_, history);
}
Segment RequestSentence::getUnderlyingSegment() const {
return request_->getSegment(index_);
}
Segment RequestSentence::getUnderlyingSegment() const { return request_->getSegment(index_); }
bool operator<(const RequestSentence &a, const RequestSentence &b) {
// Operator overload for usage in priority-queue / set.

View File

@ -1,20 +1,18 @@
#ifndef SRC_BERGAMOT_REQUEST_H_
#define SRC_BERGAMOT_REQUEST_H_
#include <cassert>
#include <future>
#include <vector>
#include "annotation.h"
#include "common/logging.h"
#include "data/types.h"
#include "definitions.h"
#include "response.h"
#include "response_builder.h"
#include "annotation.h"
#include "common/logging.h"
#include "data/types.h"
#include "translator/beam_search.h"
#include <cassert>
#include <future>
#include <vector>
namespace marian {
namespace bergamot {
@ -95,7 +93,6 @@ private:
/// within Request, while batching mechanism (Batcher) compiles Batch from
/// RequestSentence-s coming from different Requests.
class RequestSentence {
public:
RequestSentence(size_t, Ptr<Request>);

View File

@ -1,16 +1,16 @@
#ifndef SRC_BERGAMOT_RESPONSE_H_
#define SRC_BERGAMOT_RESPONSE_H_
#include "data/alignment.h"
#include "data/types.h"
#include "definitions.h"
#include "annotation.h"
#include "translator/beam_search.h"
#include <cassert>
#include <string>
#include <vector>
#include "annotation.h"
#include "data/alignment.h"
#include "data/types.h"
#include "definitions.h"
#include "translator/beam_search.h"
namespace marian {
namespace bergamot {

View File

@ -1,11 +1,11 @@
#include "response_builder.h"
#include "response_options.h"
namespace marian {
namespace bergamot {
void ResponseBuilder::buildQualityScores(Histories &histories,
Response &response) {
void ResponseBuilder::buildQualityScores(Histories &histories, Response &response) {
std::vector<Quality> qualityScores;
for (auto &history : histories) {
// TODO(jerin): Change hardcode of nBest = 1
@ -20,13 +20,11 @@ void ResponseBuilder::buildQualityScores(Histories &histories,
auto normalizedPathScore = std::get<2>(result);
auto wordQualities = hyp->tracebackWordScores();
wordQualities.pop_back();
response.qualityScores.push_back(
Quality{normalizedPathScore, wordQualities});
response.qualityScores.push_back(Quality{normalizedPathScore, wordQualities});
}
}
void ResponseBuilder::buildAlignments(Histories &histories,
Response &response) {
void ResponseBuilder::buildAlignments(Histories &histories, Response &response) {
for (auto &history : histories) {
// TODO(jerin): Change hardcode of nBest = 1
NBestList onebest = history->nBest(1);
@ -40,8 +38,7 @@ void ResponseBuilder::buildAlignments(Histories &histories,
auto hyp = std::get<1>(result);
auto softAlignment = hyp->tracebackAlignment();
auto threshold = responseOptions_.alignmentThreshold;
auto hardAlignment =
data::ConvertSoftAlignToHardAlign(softAlignment, threshold);
auto hardAlignment = data::ConvertSoftAlignToHardAlign(softAlignment, threshold);
Alignment unified_alignment;
for (auto &p : hardAlignment) {
unified_alignment.emplace_back(Point{p.srcPos, p.tgtPos, p.prob});
@ -51,8 +48,7 @@ void ResponseBuilder::buildAlignments(Histories &histories,
}
}
void ResponseBuilder::buildTranslatedText(Histories &histories,
Response &response) {
void ResponseBuilder::buildTranslatedText(Histories &histories, Response &response) {
// Reserving length at least as much as source_ seems like a reasonable
// thing to do to avoid reallocations.
response.target.text.reserve(response.source.text.size());

View File

@ -24,11 +24,9 @@ public:
/// or not in the response and any additional configurable parameters.
/// @param [in] vocabs: marian vocab object (used in decoding)
/// @param [in] promise: promise to set with the constructed Response.
ResponseBuilder(ResponseOptions responseOptions, AnnotatedText &&source,
Vocabs &vocabs,
ResponseBuilder(ResponseOptions responseOptions, AnnotatedText &&source, Vocabs &vocabs,
std::promise<Response> &&promise)
: responseOptions_(responseOptions), source_(std::move(source)),
vocabs_(vocabs), promise_(std::move(promise)) {}
: responseOptions_(responseOptions), source_(std::move(source)), vocabs_(vocabs), promise_(std::move(promise)) {}
/// Constructs and sets the promise of a Response object from obtained
/// histories after translating.
@ -38,8 +36,7 @@ public:
// TODO(jerinphilip) load ResponseOptions into options and turn build
// functions on or off.
// responseOptions_ is unused, but we can try something here.
ABORT_IF(source_.numSentences() != histories.size(),
"Mismatch in source and translated sentences");
ABORT_IF(source_.numSentences() != histories.size(), "Mismatch in source and translated sentences");
Response response;
// Move source_ into response.

View File

@ -1,48 +1,42 @@
#include "sentence_splitter.h"
#include <string>
#include "common/cli_helper.h"
#include "common/logging.h"
#include "common/options.h"
#include <string>
namespace marian {
namespace bergamot {
SentenceSplitter::SentenceSplitter(marian::Ptr<marian::Options> options)
: options_(options) {
SentenceSplitter::SentenceSplitter(marian::Ptr<marian::Options> options) : options_(options) {
std::string smode_str = options_->get<std::string>("ssplit-mode", "");
mode_ = string2splitmode(smode_str);
std::string ssplit_prefix_file =
options_->get<std::string>("ssplit-prefix-file", "");
std::string ssplit_prefix_file = options_->get<std::string>("ssplit-prefix-file", "");
if (ssplit_prefix_file.size()) {
ssplit_prefix_file = marian::cli::interpolateEnvVars(ssplit_prefix_file);
LOG(info, "Loading protected prefixes for sentence splitting from {}",
ssplit_prefix_file);
LOG(info, "Loading protected prefixes for sentence splitting from {}", ssplit_prefix_file);
ssplit_.load(ssplit_prefix_file);
} else {
LOG(warn, "Missing list of protected prefixes for sentence splitting. "
LOG(warn,
"Missing list of protected prefixes for sentence splitting. "
"Set with --ssplit-prefix-file.");
}
}
ug::ssplit::SentenceStream
SentenceSplitter::createSentenceStream(const string_view &input) {
ug::ssplit::SentenceStream SentenceSplitter::createSentenceStream(const string_view &input) {
std::string_view input_converted(input.data(), input.size());
return std::move(
ug::ssplit::SentenceStream(input_converted, this->ssplit_, mode_));
return std::move(ug::ssplit::SentenceStream(input_converted, this->ssplit_, mode_));
}
ug::ssplit::SentenceStream::splitmode
SentenceSplitter::string2splitmode(const std::string &m) {
ug::ssplit::SentenceStream::splitmode SentenceSplitter::string2splitmode(const std::string &m) {
typedef ug::ssplit::SentenceStream::splitmode splitmode;
// @TODO: throw Exception on error
if (m == "sentence" || m == "Sentence")
return splitmode::one_sentence_per_line;
if (m == "paragraph" || m == "Paragraph")
return splitmode::one_paragraph_per_line;
if (m == "sentence" || m == "Sentence") return splitmode::one_sentence_per_line;
if (m == "paragraph" || m == "Paragraph") return splitmode::one_paragraph_per_line;
if (m != "wrapped_text" && m != "WrappedText" && m != "wrappedText") {
LOG(warn, "Ignoring unknown text input format specification: {}.", m);
}

View File

@ -1,11 +1,12 @@
#ifndef SRC_BERGAMOT_SENTENCE_SPLITTER_H_
#define SRC_BERGAMOT_SENTENCE_SPLITTER_H_
#include <string>
#include "common/options.h"
#include "data/types.h"
#include "ssplit.h"
#include "definitions.h"
#include <string>
#include "ssplit.h"
namespace marian {
namespace bergamot {

View File

@ -1,17 +1,20 @@
#include "service.h"
#include "batch.h"
#include "definitions.h"
#include <string>
#include <utility>
#include "batch.h"
#include "definitions.h"
namespace marian {
namespace bergamot {
Service::Service(Ptr<Options> options, MemoryBundle memoryBundle)
: requestId_(0), options_(options),
: requestId_(0),
options_(options),
vocabs_(options, std::move(memoryBundle.vocabs)),
text_processor_(vocabs_, options), batcher_(options),
text_processor_(vocabs_, options),
batcher_(options),
numWorkers_(options->get<int>("cpu-threads")),
modelMemory_(std::move(memoryBundle.model)),
shortlistMemory_(std::move(memoryBundle.shortlist))
@ -41,9 +44,7 @@ void Service::build_translators(Ptr<Options> options, size_t numTranslators) {
}
}
void Service::initialize_blocking_translator() {
translators_.back().initialize();
}
void Service::initialize_blocking_translator() { translators_.back().initialize(); }
void Service::blocking_translate() {
Batch batch;
@ -84,13 +85,9 @@ void Service::async_translate() {
}
}
#else // WASM_COMPATIBLE_SOURCE
void Service::initialize_async_translators() {
ABORT("Cannot run in async mode without multithreading.");
}
void Service::initialize_async_translators() { ABORT("Cannot run in async mode without multithreading."); }
void Service::async_translate() {
ABORT("Cannot run in async mode without multithreading.");
}
void Service::async_translate() { ABORT("Cannot run in async mode without multithreading."); }
#endif // WASM_COMPATIBLE_SOURCE
std::future<Response> Service::translate(std::string &&input) {
@ -98,16 +95,12 @@ std::future<Response> Service::translate(std::string &&input) {
return translate(std::move(input), responseOptions);
}
std::vector<Response>
Service::translateMultiple(std::vector<std::string> &&inputs,
ResponseOptions responseOptions) {
std::vector<Response> Service::translateMultiple(std::vector<std::string> &&inputs, ResponseOptions responseOptions) {
// We queue the individual Requests so they get compiled at batches to be
// efficiently translated.
std::vector<std::future<Response>> responseFutures;
for (auto &input : inputs) {
std::future<Response> inputResponse =
queueRequest(std::move(input), responseOptions);
std::future<Response> inputResponse = queueRequest(std::move(input), responseOptions);
responseFutures.push_back(std::move(inputResponse));
}
@ -126,8 +119,7 @@ Service::translateMultiple(std::vector<std::string> &&inputs,
return responses;
}
std::future<Response> Service::queueRequest(std::string &&input,
ResponseOptions responseOptions) {
std::future<Response> Service::queueRequest(std::string &&input, ResponseOptions responseOptions) {
Segments segments;
AnnotatedText source(std::move(input));
text_processor_.process(source, segments);
@ -135,19 +127,15 @@ std::future<Response> Service::queueRequest(std::string &&input,
std::promise<Response> responsePromise;
auto future = responsePromise.get_future();
ResponseBuilder responseBuilder(responseOptions, std::move(source), vocabs_,
std::move(responsePromise));
Ptr<Request> request = New<Request>(requestId_++, std::move(segments),
std::move(responseBuilder));
ResponseBuilder responseBuilder(responseOptions, std::move(source), vocabs_, std::move(responsePromise));
Ptr<Request> request = New<Request>(requestId_++, std::move(segments), std::move(responseBuilder));
batcher_.addWholeRequest(request);
return future;
}
std::future<Response> Service::translate(std::string &&input,
ResponseOptions responseOptions) {
std::future<Response> future =
queueRequest(std::move(input), responseOptions);
std::future<Response> Service::translate(std::string &&input, ResponseOptions responseOptions) {
std::future<Response> future = queueRequest(std::move(input), responseOptions);
dispatchTranslate();
return future;
}
@ -163,7 +151,6 @@ void Service::dispatchTranslate() {
Service::~Service() {
#ifndef WASM_COMPATIBLE_SOURCE
for (size_t workerId = 0; workerId < numWorkers_; workerId++) {
Batch poison = Batch::poison();
pcqueue_.ProduceSwap(poison);
}

View File

@ -60,7 +60,6 @@ namespace bergamot {
/// file supplied through config).
///
class Service {
public:
/// Construct Service from Marian options. If memoryBundle is empty, Service is
/// initialized from file-based loading. Otherwise, Service is initialized from
@ -97,8 +96,7 @@ public:
/// @param [in] responseOptions: Options indicating whether or not to include
/// some member in the Response, also specify any additional configurable
/// parameters.
std::future<Response> translate(std::string &&source,
ResponseOptions options);
std::future<Response> translate(std::string &&source, ResponseOptions options);
/// Translate multiple text-blobs in a single *blocking* API call, providing
/// ResponseOptions which applies across all text-blobs dictating how to
@ -117,19 +115,14 @@ public:
/// to include some member in the Response, also specify any additional
/// configurable parameters.
std::vector<Response>
translateMultiple(std::vector<std::string> &&source,
ResponseOptions responseOptions);
std::vector<Response> translateMultiple(std::vector<std::string> &&source, ResponseOptions responseOptions);
/// Returns if model is alignment capable or not.
bool isAlignmentSupported() const {
return options_->hasAndNotEmpty("alignment");
}
bool isAlignmentSupported() const { return options_->hasAndNotEmpty("alignment"); }
private:
/// Queue an input for translation.
std::future<Response> queueRequest(std::string &&input,
ResponseOptions responseOptions);
std::future<Response> queueRequest(std::string &&input, ResponseOptions responseOptions);
/// Dispatch call to translate after inserting in queue
void dispatchTranslate();
@ -164,8 +157,7 @@ private:
/// Holds instances of batch translators, just one in case
/// of single-threaded application, numWorkers_ in case of multithreaded
/// setting.
std::vector<BatchTranslator>
translators_; // ORDER DEPENDENCY (modelMemory_, shortlistMemory_)
std::vector<BatchTranslator> translators_; // ORDER DEPENDENCY (modelMemory_, shortlistMemory_)
/// Stores requestId of active request. Used to establish
/// ordering among requests and logging/book-keeping.

View File

@ -1,39 +1,33 @@
#include "text_processor.h"
#include <vector>
#include "annotation.h"
#include "common/options.h"
#include "data/types.h"
#include "definitions.h"
#include "annotation.h"
#include "common/options.h"
#include <vector>
namespace marian {
namespace bergamot {
Segment TextProcessor::tokenize(const string_view &segment,
std::vector<string_view> &wordRanges) {
Segment TextProcessor::tokenize(const string_view &segment, std::vector<string_view> &wordRanges) {
// vocabs_->sources().front() is invoked as we currently only support one source vocab
return vocabs_.sources().front()->encodeWithByteRanges(
segment, wordRanges, /*addEOS=*/false, /*inference=*/true);
return vocabs_.sources().front()->encodeWithByteRanges(segment, wordRanges, /*addEOS=*/false, /*inference=*/true);
}
TextProcessor::TextProcessor(Vocabs &vocabs,
Ptr<Options> options)
: vocabs_(vocabs), sentence_splitter_(options) {
TextProcessor::TextProcessor(Vocabs &vocabs, Ptr<Options> options) : vocabs_(vocabs), sentence_splitter_(options) {
max_length_break_ = options->get<int>("max-length-break");
max_length_break_ = max_length_break_ - 1;
ABORT_IF(max_length_break_ < 0, "max-length-break cannot be < 0");
}
void TextProcessor::process(AnnotatedText &source, Segments &segments) {
string_view query = string_view(source.text);
auto sentenceStream = sentence_splitter_.createSentenceStream(query);
std::string_view sentenceStringPiece;
while (sentenceStream >> sentenceStringPiece) {
marian::string_view sentence(sentenceStringPiece.data(),
sentenceStringPiece.size());
marian::string_view sentence(sentenceStringPiece.data(), sentenceStringPiece.size());
std::vector<string_view> wordRanges;
Segment segment = tokenize(sentence, wordRanges);
@ -48,11 +42,9 @@ void TextProcessor::process(AnnotatedText &source, Segments &segments) {
}
}
void TextProcessor::wrap(Segment &segment,
std::vector<string_view> &wordRanges,
Segments &segments, AnnotatedText &source) {
for (size_t offset = 0; offset < segment.size();
offset += max_length_break_) {
void TextProcessor::wrap(Segment &segment, std::vector<string_view> &wordRanges, Segments &segments,
AnnotatedText &source) {
for (size_t offset = 0; offset < segment.size(); offset += max_length_break_) {
auto start = segment.begin() + offset;
size_t left = segment.size() - offset;

View File

@ -1,16 +1,15 @@
#ifndef SRC_BERGAMOT_TEXT_PROCESSOR_H_
#define SRC_BERGAMOT_TEXT_PROCESSOR_H_
#include <vector>
#include "annotation.h"
#include "data/types.h"
#include "data/vocab.h"
#include "definitions.h"
#include "annotation.h"
#include "sentence_splitter.h"
#include "vocabs.h"
#include <vector>
namespace marian {
namespace bergamot {
@ -29,12 +28,10 @@ public:
private:
// Tokenizes an input string, returns Words corresponding. Loads the
// corresponding byte-ranges into tokenRanges.
Segment tokenize(const string_view &input,
std::vector<string_view> &tokenRanges);
Segment tokenize(const string_view &input, std::vector<string_view> &tokenRanges);
// Wrap into sentences of at most max_length_break_ tokens and add to source.
void wrap(Segment &sentence, std::vector<string_view> &tokenRanges,
Segments &segments, AnnotatedText &source);
void wrap(Segment &sentence, std::vector<string_view> &tokenRanges, Segments &segments, AnnotatedText &source);
// shorthand, used only in truncate()
// vocabs_->sources().front() is invoked as we currently only support one source vocab

View File

@ -12,8 +12,7 @@ public:
if (!vocabMemories.empty()) {
// load vocabs from buffer
load(std::move(vocabMemories));
}
else{
} else {
// load vocabs from file
auto vocabPaths = options->get<std::vector<std::string>>("vocabs");
load(vocabPaths);
@ -21,14 +20,10 @@ public:
}
/// Get all source vocabularies (as a vector)
const std::vector<Ptr<Vocab const>>& sources() const {
return srcVocabs_;
}
const std::vector<Ptr<Vocab const>>& sources() const { return srcVocabs_; }
/// Get the target vocabulary
const Ptr<Vocab const>& target() const {
return trgVocab_;
}
const Ptr<Vocab const>& target() const { return trgVocab_; }
private:
std::vector<Ptr<Vocab const>> srcVocabs_; // source vocabularies

View File

@ -23,8 +23,7 @@ EMSCRIPTEN_BINDINGS(aligned_memory) {
class_<AlignedMemory>("AlignedMemory")
.constructor<std::size_t, std::size_t>()
.function("size", &AlignedMemory::size)
.function("getByteArrayView", &getByteArrayView)
;
.function("getByteArrayView", &getByteArrayView);
register_vector<AlignedMemory*>("AlignedMemoryList");
}
@ -41,15 +40,13 @@ std::vector<std::shared_ptr<AlignedMemory>> prepareVocabsSmartMemories(std::vect
if (vocabsMemories.size() == 2) {
auto targetVocabMemory = std::make_shared<AlignedMemory>(std::move(*(vocabsMemories[1])));
vocabsSmartMemories.push_back(std::move(targetVocabMemory));
}
else {
} else {
vocabsSmartMemories.push_back(sourceVocabMemory);
}
return vocabsSmartMemories;
}
marian::bergamot::MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory,
AlignedMemory* shortlistMemory,
marian::bergamot::MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shortlistMemory,
std::vector<AlignedMemory*> uniqueVocabsMemories) {
marian::bergamot::MemoryBundle memoryBundle;
memoryBundle.model = std::move(*modelMemory);
@ -59,19 +56,18 @@ marian::bergamot::MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory,
return memoryBundle;
}
TranslationModel* TranslationModelFactory(const std::string &config,
AlignedMemory* modelMemory,
TranslationModel* TranslationModelFactory(const std::string& config, AlignedMemory* modelMemory,
AlignedMemory* shortlistMemory,
std::vector<AlignedMemory*> uniqueVocabsMemories) {
return new TranslationModel(config, std::move(prepareMemoryBundle(modelMemory, shortlistMemory, uniqueVocabsMemories)));
return new TranslationModel(config,
std::move(prepareMemoryBundle(modelMemory, shortlistMemory, uniqueVocabsMemories)));
}
EMSCRIPTEN_BINDINGS(translation_model) {
class_<TranslationModel>("TranslationModel")
.constructor(&TranslationModelFactory, allow_raw_pointers())
.function("translate", &TranslationModel::translateMultiple)
.function("isAlignmentSupported", &TranslationModel::isAlignmentSupported)
;
.function("isAlignmentSupported", &TranslationModel::isAlignmentSupported);
// ^ We redirect Service::translateMultiple to WASMBound::translate instead. Sane API is
// translate. If and when async comes, we can be done with this inconsistency.

View File

@ -12,8 +12,4 @@ typedef marian::bergamot::ResponseOptions TranslationRequest;
using namespace emscripten;
// Binding code
EMSCRIPTEN_BINDINGS(translation_request) {
class_<TranslationRequest>("TranslationRequest")
.constructor<>()
;
}
EMSCRIPTEN_BINDINGS(translation_request) { class_<TranslationRequest>("TranslationRequest").constructor<>(); }

View File

@ -4,6 +4,7 @@
*/
#include <emscripten/bind.h>
#include <vector>
#include "response.h"