From b43747ab5cc42568d33c2f300daea93d2c26750a Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Mon, 18 Apr 2016 21:16:07 +0200 Subject: [PATCH] organize lm queries into batches --- src/decoder/kenlm.cpp | 3 +++ src/decoder/kenlm.h | 1 + src/decoder/search.h | 35 ++++++++++++++++++++++------------- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/decoder/kenlm.cpp b/src/decoder/kenlm.cpp index 203ece1f..11c88ab8 100644 --- a/src/decoder/kenlm.cpp +++ b/src/decoder/kenlm.cpp @@ -102,3 +102,6 @@ float LM::GetWeight() const { return weight_; } +size_t LM::size() const { + return vm_.size(); +} diff --git a/src/decoder/kenlm.h b/src/decoder/kenlm.h index ae352c88..15ce84cd 100644 --- a/src/decoder/kenlm.h +++ b/src/decoder/kenlm.h @@ -55,6 +55,7 @@ class LM { WordPairs::const_iterator end() const; size_t GetIndex() const; float GetWeight() const; + size_t size() const; private: std::unique_ptr lm_; diff --git a/src/decoder/search.h b/src/decoder/search.h index a5288fe0..be81d5e2 100644 --- a/src/decoder/search.h +++ b/src/decoder/search.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -137,23 +138,31 @@ class Search { std::vector costs(rows * cols); states.resize(rows * cols); - { - ThreadPool pool(4); - for(size_t i = 0; i < prevHyps.size(); i++) { - auto call = [i, cols, &prevHyps, &lm, &costs, &states] { - const KenlmState state = prevHyps[i]->GetLMStates()[lm.GetIndex()]; - KenlmState stateUnk; - float costUnk = lm.Score(state, 0, stateUnk); - std::fill(costs.begin() + i * cols, costs.begin() + i * cols + cols, costUnk); - std::fill(states.begin() + i * cols, states.begin() + i * cols + cols, stateUnk); - for(auto& wp : lm) { - costs[i * cols + wp.second] = lm.Score(state, wp.first, states[i * cols + wp.second]); - } + for(size_t i = 0; i < prevHyps.size(); i++) { + const KenlmState state = prevHyps[i]->GetLMStates()[lm.GetIndex()]; + KenlmState stateUnk; + float costUnk = lm.Score(state, 0, stateUnk); + 
std::fill(costs.begin() + i * cols, costs.begin() + i * cols + cols, costUnk); + std::fill(states.begin() + i * cols, states.begin() + i * cols + cols, stateUnk); + } + + { + ThreadPool pool(8); + size_t batchSize = 1000; + for(size_t batchStart = 0; batchStart < lm.size(); batchStart += batchSize) { + auto call = [batchStart, batchSize, cols, &prevHyps, &lm, &costs, &states] { + size_t batchEnd = min(batchStart + batchSize, lm.size()); + for(auto it = lm.begin() + batchStart; it != lm.begin() + batchEnd; ++it) + for(size_t i = 0; i < prevHyps.size(); i++) { + const KenlmState state = prevHyps[i]->GetLMStates()[lm.GetIndex()]; + costs[i * cols + it->second] = lm.Score(state, it->first, states[i * cols + it->second]); + } }; pool.enqueue(call); } } - cudaSetDevice(device_); + + cudaSetDevice(device_); // ??? thrust::copy(costs.begin(), costs.end(), LmProbs.begin()); }