mirror of
https://github.com/browsermt/bergamot-translator.git
synced 2024-09-17 16:47:18 +03:00
single-threaded run with --cpu-threads 0 (#10)
This commit is contained in:
parent
4764f11e95
commit
f1d9f67b56
@ -36,8 +36,7 @@ BatchTranslator::BatchTranslator(DeviceId const device,
|
||||
graph_->forward();
|
||||
}
|
||||
|
||||
void BatchTranslator::translate(RequestSentences &requestSentences,
|
||||
Histories &histories) {
|
||||
void BatchTranslator::translate(RequestSentences &requestSentences) {
|
||||
std::vector<data::SentenceTuple> batchVector;
|
||||
|
||||
for (auto &sentence : requestSentences) {
|
||||
@ -89,7 +88,10 @@ void BatchTranslator::translate(RequestSentences &requestSentences,
|
||||
auto trgVocab = vocabs_->back();
|
||||
auto search = New<BeamSearch>(options_, scorers_, trgVocab);
|
||||
|
||||
histories = std::move(search->search(graph_, batch));
|
||||
auto histories = std::move(search->search(graph_, batch));
|
||||
for (int i = 0; i < requestSentences.size(); i++) {
|
||||
requestSentences[i].completeSentence(histories[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// void BatchTranslator::join() { thread_.join(); }
|
||||
@ -107,10 +109,7 @@ void translation_loop(DeviceId const &device, PCQueue<PCItem> &pcqueue,
|
||||
if (pcitem.isPoison()) {
|
||||
return;
|
||||
} else {
|
||||
translator.translate(pcitem.sentences, histories);
|
||||
for (int i = 0; i < pcitem.sentences.size(); i++) {
|
||||
pcitem.sentences[i].completeSentence(histories[i]);
|
||||
}
|
||||
translator.translate(pcitem.sentences);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ public:
|
||||
|
||||
// convenience function for logging. TODO(jerin)
|
||||
std::string _identifier() { return "worker" + std::to_string(device_.no); }
|
||||
void translate(RequestSentences &requestSentences, Histories &histories);
|
||||
void translate(RequestSentences &requestSentences);
|
||||
|
||||
private:
|
||||
Ptr<Options> options_;
|
||||
|
@ -50,5 +50,30 @@ void Batcher::cleaveBatch(RequestSentences &sentences) {
|
||||
}
|
||||
}
|
||||
|
||||
// Adds every segment of `request` to the batcher: for each index in
// [0, request->numSegments()) a RequestSentence is built from that index and
// the request, then handed to addSentenceWithPriority().
void Batcher::addWholeRequest(Ptr<Request> request) {
  int segmentIndex = 0;
  while (segmentIndex < request->numSegments()) {
    // addSentenceWithPriority() takes a non-const reference, so the sentence
    // has to be a named local rather than a temporary.
    RequestSentence segment(segmentIndex, request);
    addSentenceWithPriority(segment);
    ++segmentIndex;
  }
}
|
||||
|
||||
// Repeatedly cleaves a batch from the batcher and pushes it onto `pcqueue`
// as a PCItem tagged with the running batch number, stopping as soon as
// cleaveBatch() yields an empty batch.
//
// Progress is logged once for every 500th batch produced. The log statement
// lives inside the "batch produced" path so that the terminating (empty)
// iteration can neither repeat the previous message nor report
// "Queuing batch 0" when nothing was queued at all.
void Batcher::enqueue(PCQueue<PCItem> &pcqueue) {
  for (;;) {
    RequestSentences batchSentences;
    cleaveBatch(batchSentences);

    // An empty batch means there is nothing left to hand over.
    if (batchSentences.size() == 0) {
      break;
    }

    PCItem pcitem(batchNumber_++, std::move(batchSentences));
    // NOTE(review): ProduceSwap presumably swaps `pcitem` into the queue,
    // leaving it in an unspecified state afterwards; it is not reused here.
    pcqueue.ProduceSwap(pcitem);

    if (batchNumber_ % 500 == 0) {
      LOG(info, "Queuing batch {}", batchNumber_);
    }
  }
}
|
||||
|
||||
} // namespace bergamot
|
||||
} // namespace marian
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include "common/options.h"
|
||||
#include "data/corpus_base.h"
|
||||
#include "definitions.h"
|
||||
#include "pcqueue.h"
|
||||
#include "request.h"
|
||||
|
||||
#include <set>
|
||||
@ -19,6 +20,8 @@ public:
|
||||
// sentence. This method inserts the sentence into the internal data-structure
|
||||
// which maintains priority among sentences from multiple concurrent requests.
|
||||
void addSentenceWithPriority(RequestSentence &sentence);
|
||||
void addWholeRequest(Ptr<Request> request);
|
||||
void enqueue(PCQueue<PCItem> &pcqueue);
|
||||
|
||||
// Loads sentences with sentences compiled from (tentatively) multiple
|
||||
// requests optimizing for both padding and priority.
|
||||
@ -27,6 +30,7 @@ public:
|
||||
private:
|
||||
unsigned int max_input_tokens_;
|
||||
std::vector<std::set<RequestSentence>> bucket_;
|
||||
unsigned int batchNumber_{0};
|
||||
};
|
||||
|
||||
} // namespace bergamot
|
||||
|
@ -14,13 +14,17 @@ Service::Service(Ptr<Options> options)
|
||||
text_processor_(vocabs_, options), batcher_(options),
|
||||
pcqueue_(2 * options->get<int>("cpu-threads")) {
|
||||
|
||||
workers_.reserve(numWorkers_);
|
||||
|
||||
for (int cpuId = 0; cpuId < numWorkers_; cpuId++) {
|
||||
workers_.emplace_back([&] {
|
||||
marian::DeviceId deviceId(cpuId, DeviceType::cpu);
|
||||
translation_loop(deviceId, pcqueue_, vocabs_, options);
|
||||
});
|
||||
if (numWorkers_ > 0) {
|
||||
workers_.reserve(numWorkers_);
|
||||
for (int cpuId = 0; cpuId < numWorkers_; cpuId++) {
|
||||
workers_.emplace_back([&] {
|
||||
marian::DeviceId deviceId(cpuId, DeviceType::cpu);
|
||||
translation_loop(deviceId, pcqueue_, vocabs_, options);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
marian::DeviceId deviceId(/*cpuId=*/0, DeviceType::cpu);
|
||||
translator = new BatchTranslator(deviceId, vocabs_, options);
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,27 +57,28 @@ std::future<TranslationResult> Service::translate(std::string &&input) {
|
||||
std::move(segments), std::move(sourceAlignments),
|
||||
std::move(translationResultPromise));
|
||||
|
||||
for (int i = 0; i < request->numSegments(); i++) {
|
||||
RequestSentence requestSentence(i, request);
|
||||
batcher_.addSentenceWithPriority(requestSentence);
|
||||
batcher_.addWholeRequest(request);
|
||||
if (numWorkers_ > 0) {
|
||||
batcher_.enqueue(pcqueue_);
|
||||
} else {
|
||||
// Queue single-threaded
|
||||
int numSentences;
|
||||
do {
|
||||
RequestSentences batchSentences;
|
||||
batcher_.cleaveBatch(batchSentences);
|
||||
numSentences = batchSentences.size();
|
||||
|
||||
if (numSentences > 0) {
|
||||
translator->translate(batchSentences);
|
||||
batchNumber_++;
|
||||
}
|
||||
|
||||
if (batchNumber_ % 500 == 0) {
|
||||
LOG(info, "Tranlsating batch {}", batchNumber_);
|
||||
}
|
||||
} while (numSentences > 0);
|
||||
}
|
||||
|
||||
int numSentences;
|
||||
do {
|
||||
RequestSentences batchSentences;
|
||||
batcher_.cleaveBatch(batchSentences);
|
||||
numSentences = batchSentences.size();
|
||||
|
||||
if (numSentences > 0) {
|
||||
PCItem pcitem(batchNumber_++, std::move(batchSentences));
|
||||
pcqueue_.ProduceSwap(pcitem);
|
||||
}
|
||||
|
||||
if (batchNumber_ % 500 == 0) {
|
||||
LOG(info, "Queuing batch {}", batchNumber_);
|
||||
}
|
||||
} while (numSentences > 0);
|
||||
|
||||
return future;
|
||||
}
|
||||
|
||||
|
@ -70,6 +70,9 @@ private:
|
||||
Batcher batcher_;
|
||||
PCQueue<PCItem> pcqueue_;
|
||||
std::vector<std::thread> workers_;
|
||||
|
||||
// Optional
|
||||
BatchTranslator *translator{nullptr};
|
||||
};
|
||||
|
||||
std::vector<Ptr<const Vocab>> loadVocabularies(Ptr<Options> options);
|
||||
|
Loading…
Reference in New Issue
Block a user