Defer model loading to parallel worker thread (#303)

This commit is contained in:
Jelmer 2022-01-14 10:30:38 +00:00 committed by GitHub
parent 71b84b7c72
commit 13c55e2693
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 4 deletions

View File

@@ -44,10 +44,6 @@ TranslationModel::TranslationModel(const Config &options, MemoryBundle &&memory
srcIdx, trgIdx, shared_vcb);
}
}
for (size_t idx = 0; idx < replicas; idx++) {
loadBackend(idx);
}
}
void TranslationModel::loadBackend(size_t idx) {
@@ -172,6 +168,12 @@ Ptr<marian::data::CorpusBatch> TranslationModel::convertToMarianBatch(Batch &bat
void TranslationModel::translateBatch(size_t deviceId, Batch &batch) {
auto &backend = backend_[deviceId];
if (!backend.initialized) {
loadBackend(deviceId);
backend.initialized = true;
}
BeamSearch search(options_, backend.scorerEnsemble, vocabs_.target());
Histories histories = search.search(backend.graph, convertToMarianBatch(batch));
batch.completeBatch(histories);

View File

@@ -107,6 +107,7 @@ class TranslationModel {
Graph graph;
ScorerEnsemble scorerEnsemble;
bool initialized{false};
};
// ShortlistGenerator is purely const, we don't need one per thread.