Batch cleanup

Moves Batch into batch.{h,cpp}.

- Id_ no longer used due to overflow concerns. (#27)
- size_t for places where signed integer is not preferred.
- Adjustments to response.{h,cpp}
This commit is contained in:
Jerin Philip 2021-02-16 19:46:40 +00:00
parent 65e7406970
commit 4c8b655ac5
12 changed files with 109 additions and 104 deletions

View File

@ -11,6 +11,7 @@ add_library(bergamot-translator STATIC
service.cpp
batcher.cpp
response.cpp
batch.cpp
)
target_link_libraries(bergamot-translator marian ssplit)

28
src/translator/batch.cpp Normal file
View File

@ -0,0 +1,28 @@
#include "batch.h"
#include "request.h"
namespace marian {
namespace bergamot {
void Batch::log() {
size_t numTokens{0}, maxLength{0};
for (auto &sentence : sentences_) {
numTokens += sentence.numTokens();
maxLength = std::max(maxLength, static_cast<size_t>(sentence.numTokens()));
}
LOG(info, "Batch(tokens={}, max-length={}, sentences_={})", numTokens,
maxLength, sentences_.size());
}
void Batch::add(const RequestSentence &sentence) {
sentences_.push_back(sentence);
}
void Batch::completeBatch(const Histories &histories) {
for (int i = 0; i < sentences_.size(); i++) {
sentences_[i].completeSentence(histories[i]);
}
}
} // namespace bergamot
} // namespace marian

52
src/translator/batch.h Normal file
View File

@ -0,0 +1,52 @@
#ifndef SRC_BERGAMOT_BATCH_H
#define SRC_BERGAMOT_BATCH_H
#include "request.h"
#include "translator/beam_search.h"
namespace marian {
namespace bergamot {
class Batch {
public:
Batch() {}
void clear() { sentences_.clear(); }
// Methods to construct and determine poison.
static Batch poison() {
Batch batch;
batch.poison_ = true;
return batch;
}
bool isPoison() const { return poison_; }
size_t size() const { return sentences_.size(); }
void add(const RequestSentence &sentence);
// Accessors to read from a Batch. For use in BatchTranslator (consumer on a
// PCQueue holding batches).
//
// sentences() are used to access sentences to construct marian internal
// batch.
const RequestSentences &sentences() { return sentences_; }
// On obtaining Histories after translating a batch, completeBatch can be
// called with Histories , which forwards the call to Request through
// RequestSentence and triggers completion, by setting the promised value to
// the future given to client.
void completeBatch(const Histories &histories);
// Convenience function to log batch-statistics. numTokens, max-length.
void log();
private:
bool poison_{false};
RequestSentences sentences_;
};
} // namespace bergamot
} // namespace marian
#endif // SRC_BERGAMOT_BATCH_H_

View File

@ -1,4 +1,5 @@
#include "batch_translator.h"
#include "batch.h"
#include "common/logging.h"
#include "data/corpus.h"
#include "data/text_input.h"

View File

@ -4,6 +4,7 @@
#include <string>
#include <vector>
#include "batch.h"
#include "common/utils.h"
#include "data/shortlist.h"
#include "definitions.h"

View File

@ -1,4 +1,5 @@
#include "batcher.h"
#include "batch.h"
#include "common/logging.h"
#include <cassert>
@ -27,7 +28,7 @@ bool Batcher::cleaveBatch(Batch &batch) {
// has to be enhanced with optimizing over priority. The baseline
// implementation should at least be as fast as marian's maxi-batch with full
// corpus size as maxi-batch size.
batch.reset();
batch.clear();
int paddedBatchSize = 0;
for (int length = 0; length < bucket_.size(); length++) {
@ -41,18 +42,13 @@ bool Batcher::cleaveBatch(Batch &batch) {
} else {
// Check if elements exist
assert(batch.size() > 0);
batch.setId(++batchNumber_);
return true;
}
}
}
if (batch.size()) {
batch.setId(++batchNumber_);
return true;
} else {
return false;
}
bool isValidBatch = batch.size() > 0;
return isValidBatch;
}
void Batcher::addWholeRequest(Ptr<Request> request) {

View File

@ -1,6 +1,7 @@
#ifndef SRC_BERGAMOT_BATCHER_H_
#define SRC_BERGAMOT_BATCHER_H_
#include "batch.h"
#include "common/options.h"
#include "data/corpus_base.h"
#include "definitions.h"

View File

@ -92,39 +92,5 @@ bool operator<(const RequestSentence &a, const RequestSentence &b) {
// ----------------------------------------------------------------------
void Batch::reset() {
Id_ = 0;
sentences_.clear();
}
void Batch::log() {
int numTokens{0}, maxLength{0};
for (auto &sentence : sentences_) {
numTokens += sentence.numTokens();
maxLength = std::max(maxLength, static_cast<int>(sentence.numTokens()));
}
LOG(info, "Batch(Id_={}, tokens={}, max-length={}, sentences_={})", Id_,
numTokens, maxLength, sentences_.size());
}
void Batch::add(const RequestSentence &sentence) {
sentences_.push_back(sentence);
}
void Batch::setId(int Id) {
assert(Id > 0);
Id_ = Id;
if (Id % 500 == 0) {
log();
}
}
void Batch::completeBatch(const Histories &histories) {
for (int i = 0; i < sentences_.size(); i++) {
sentences_[i].completeSentence(histories[i]);
}
}
} // namespace bergamot
} // namespace marian

View File

@ -13,9 +13,6 @@
// batching mechanism access to the segment within the request. The backref to
// Request allows event triggering the barrier upon completion of the last
// sentence by a worker.
//
// Batch: is a vector of RequestSentences tagged with a batchNumber, which is
// what the PCQueue holds. Batch is "produced" by the Batcher.
#ifndef SRC_BERGAMOT_REQUEST_H_
#define SRC_BERGAMOT_REQUEST_H_
@ -122,57 +119,6 @@ private:
typedef std::vector<RequestSentence> RequestSentences;
class Batch {
public:
Batch() { reset(); }
// Reset is required to reuse the same batch by consumer.
void reset();
// Methods to construct and determine poison.
static Batch poison() {
Batch poison_;
poison_.Id_ = -1;
return poison_;
}
bool isPoison() const { return (Id_ == -1); }
size_t size() const { return sentences_.size(); }
// Accessors to load data into a batch. Use add(...) to add sentences into a
// batch. Once complete with a legal batch, use setId to set Id_ accordingly.
// setId only allows setting Id > 0. For use in Batcher, which acts as a
// producer to a PCQueue holding "Batch"es.
//
// Id_ =
// -1 : Batch::Poison
// 0 : Empty Batch
// >0 : Legal batch containing sentences
void add(const RequestSentence &sentence);
void setId(int Id);
// Accessors to read from a Batch. For use in BatchTranslator (consumer on a
// PCQueue holding batches).
//
// sentences() are used to access sentences to construct marian internal
// batch.
const RequestSentences &sentences() { return sentences_; }
// On obtaining Histories after translating a batch, completeBatch can be
// called with Histories , which forwards the call to Request through
// RequestSentence and triggers completion, by setting the promised value to
// the future given to client.
void completeBatch(const Histories &histories);
// Convenience function to log batch-statistics. numTokens, max-length.
// TODO(jerinphilip): Use to log and report packing efficiency.
void log();
private:
int Id_;
RequestSentences sentences_;
};
} // namespace bergamot
} // namespace marian

View File

@ -16,7 +16,11 @@ Response::Response(std::string &&source,
void Response::move(std::string &source, std::string &translation,
SentenceMappings &sentenceMappings) {
// Construct required stuff first.
constructTranslation();
constructSentenceMappings(sentenceMappings);
// Move content out.
source = std::move(source_);
translation = std::move(translation_);
@ -28,6 +32,13 @@ void Response::move(std::string &source, std::string &translation,
}
void Response::constructTranslation() {
if (translationConstructed_) {
return;
}
// Reserving length at least as much as source_ seems like a reasonable thing
// to do to avoid reallocations.
translation_.reserve(source_.size());
// In a first step, the decoded units (individual senteneces) are compiled
// into a huge string. This is done by computing indices first and appending
@ -43,7 +54,8 @@ void Response::constructTranslation() {
Result result = onebest[0]; // Expecting only one result;
Words words = std::get<0>(result);
std::string decoded = vocabs_->back()->decode(words);
auto targetVocab = vocabs_->back();
std::string decoded = targetVocab->decode(words);
if (first) {
first = false;
} else {
@ -67,9 +79,10 @@ void Response::constructTranslation() {
const char *begin = &translation_[range.first];
targetMappings.emplace_back(begin, range.second);
targetRanges_.push_back(std::move(targetMappings));
}
translationConstructed_ = true;
}
void Response::constructSentenceMappings(

View File

@ -43,7 +43,8 @@ public:
translation_(std::move(other.translation_)),
sourceRanges_(std::move(other.sourceRanges_)),
targetRanges_(std::move(other.targetRanges_)),
histories_(std::move(other.histories_)){};
histories_(std::move(other.histories_)),
vocabs_(std::move(other.vocabs_)){};
// Prevents CopyConstruction and CopyAssignment. sourceRanges_ is constituted
// by string_view and copying invalidates the data member.
@ -66,9 +67,7 @@ public:
const Histories &histories() const { return histories_; }
const std::string &source() const { return source_; }
const std::string &translation() {
if (!translationConstructed) {
constructTranslation();
}
constructTranslation();
return translation_;
}
@ -88,8 +87,8 @@ private:
std::vector<TokenRanges> sourceRanges_;
Histories histories_;
std::vector<Ptr<Vocab const>> *vocabs_{nullptr};
bool translationConstructed{false};
std::vector<Ptr<Vocab const>> *vocabs_;
bool translationConstructed_{false};
std::string translation_;
std::vector<TokenRanges> targetRanges_;
};

View File

@ -1,4 +1,5 @@
#include "service.h"
#include "batch.h"
#include "definitions.h"
#include <string>