single-threaded run with --cpu-threads 0 (#10)

This commit is contained in:
Jerin Philip 2021-02-13 11:42:57 +00:00
parent 4764f11e95
commit f1d9f67b56
6 changed files with 70 additions and 34 deletions

View File

@ -36,8 +36,7 @@ BatchTranslator::BatchTranslator(DeviceId const device,
graph_->forward();
}
void BatchTranslator::translate(RequestSentences &requestSentences,
Histories &histories) {
void BatchTranslator::translate(RequestSentences &requestSentences) {
std::vector<data::SentenceTuple> batchVector;
for (auto &sentence : requestSentences) {
@ -89,7 +88,10 @@ void BatchTranslator::translate(RequestSentences &requestSentences,
auto trgVocab = vocabs_->back();
auto search = New<BeamSearch>(options_, scorers_, trgVocab);
histories = std::move(search->search(graph_, batch));
auto histories = std::move(search->search(graph_, batch));
for (int i = 0; i < requestSentences.size(); i++) {
requestSentences[i].completeSentence(histories[i]);
}
}
// void BatchTranslator::join() { thread_.join(); }
@ -107,10 +109,7 @@ void translation_loop(DeviceId const &device, PCQueue<PCItem> &pcqueue,
if (pcitem.isPoison()) {
return;
} else {
translator.translate(pcitem.sentences, histories);
for (int i = 0; i < pcitem.sentences.size(); i++) {
pcitem.sentences[i].completeSentence(histories[i]);
}
translator.translate(pcitem.sentences);
}
}
}

View File

@ -27,7 +27,7 @@ public:
// convenience function for logging. TODO(jerin)
std::string _identifier() { return "worker" + std::to_string(device_.no); }
void translate(RequestSentences &requestSentences, Histories &histories);
void translate(RequestSentences &requestSentences);
private:
Ptr<Options> options_;

View File

@ -50,5 +50,30 @@ void Batcher::cleaveBatch(RequestSentences &sentences) {
}
}
// Adds every segment of `request` to the batcher: for each segment index a
// RequestSentence referring back to the request is built and passed to
// addSentenceWithPriority.
void Batcher::addWholeRequest(Ptr<Request> request) {
  int segmentIdx = 0;
  while (segmentIdx < request->numSegments()) {
    // addSentenceWithPriority takes a non-const reference, so a named local
    // is required here.
    RequestSentence segment(segmentIdx, request);
    addSentenceWithPriority(segment);
    ++segmentIdx;
  }
}
// Drains the batcher into `pcqueue`.
//
// Repeatedly cleaves a batch; each non-empty batch is wrapped in a PCItem
// tagged with the running batch number (batchNumber_, post-incremented) and
// handed to the queue via ProduceSwap. The loop ends at the first empty batch.
void Batcher::enqueue(PCQueue<PCItem> &pcqueue) {
  while (true) {
    RequestSentences batchSentences;
    cleaveBatch(batchSentences);
    if (batchSentences.size() == 0) {
      break;
    }

    PCItem pcitem(batchNumber_++, std::move(batchSentences));
    pcqueue.ProduceSwap(pcitem);

    // Progress log once per 500 queued batches. Kept on the produce path so
    // the terminating empty pass cannot repeat the message (or report
    // "batch 0" when nothing was queued at all).
    if (batchNumber_ % 500 == 0) {
      LOG(info, "Queuing batch {}", batchNumber_);
    }
  }
}
} // namespace bergamot
} // namespace marian

View File

@ -4,6 +4,7 @@
#include "common/options.h"
#include "data/corpus_base.h"
#include "definitions.h"
#include "pcqueue.h"
#include "request.h"
#include <set>
@ -19,6 +20,8 @@ public:
// sentence. This method inserts the sentence into the internal data-structure
// which maintains priority among sentences from multiple concurrent requests.
void addSentenceWithPriority(RequestSentence &sentence);
void addWholeRequest(Ptr<Request> request);
void enqueue(PCQueue<PCItem> &pcqueue);
// Loads sentences with sentences compiled from (tentatively) multiple
// requests optimizing for both padding and priority.
@ -27,6 +30,7 @@ public:
private:
unsigned int max_input_tokens_;
std::vector<std::set<RequestSentence>> bucket_;
unsigned int batchNumber_{0};
};
} // namespace bergamot

View File

@ -14,13 +14,17 @@ Service::Service(Ptr<Options> options)
text_processor_(vocabs_, options), batcher_(options),
pcqueue_(2 * options->get<int>("cpu-threads")) {
workers_.reserve(numWorkers_);
for (int cpuId = 0; cpuId < numWorkers_; cpuId++) {
workers_.emplace_back([&] {
marian::DeviceId deviceId(cpuId, DeviceType::cpu);
translation_loop(deviceId, pcqueue_, vocabs_, options);
});
if (numWorkers_ > 0) {
workers_.reserve(numWorkers_);
for (int cpuId = 0; cpuId < numWorkers_; cpuId++) {
workers_.emplace_back([&] {
marian::DeviceId deviceId(cpuId, DeviceType::cpu);
translation_loop(deviceId, pcqueue_, vocabs_, options);
});
}
} else {
marian::DeviceId deviceId(/*cpuId=*/0, DeviceType::cpu);
translator = new BatchTranslator(deviceId, vocabs_, options);
}
}
@ -53,27 +57,28 @@ std::future<TranslationResult> Service::translate(std::string &&input) {
std::move(segments), std::move(sourceAlignments),
std::move(translationResultPromise));
for (int i = 0; i < request->numSegments(); i++) {
RequestSentence requestSentence(i, request);
batcher_.addSentenceWithPriority(requestSentence);
batcher_.addWholeRequest(request);
if (numWorkers_ > 0) {
batcher_.enqueue(pcqueue_);
} else {
// Queue single-threaded
int numSentences;
do {
RequestSentences batchSentences;
batcher_.cleaveBatch(batchSentences);
numSentences = batchSentences.size();
if (numSentences > 0) {
translator->translate(batchSentences);
batchNumber_++;
}
if (batchNumber_ % 500 == 0) {
LOG(info, "Tranlsating batch {}", batchNumber_);
}
} while (numSentences > 0);
}
int numSentences;
do {
RequestSentences batchSentences;
batcher_.cleaveBatch(batchSentences);
numSentences = batchSentences.size();
if (numSentences > 0) {
PCItem pcitem(batchNumber_++, std::move(batchSentences));
pcqueue_.ProduceSwap(pcitem);
}
if (batchNumber_ % 500 == 0) {
LOG(info, "Queuing batch {}", batchNumber_);
}
} while (numSentences > 0);
return future;
}

View File

@ -70,6 +70,9 @@ private:
Batcher batcher_;
PCQueue<PCItem> pcqueue_;
std::vector<std::thread> workers_;
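// Optional: only allocated by the constructor when numWorkers_ is 0
// (single-threaded mode); stays nullptr when worker threads are used.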
BatchTranslator *translator{nullptr};
};
std::vector<Ptr<const Vocab>> loadVocabularies(Ptr<Options> options);