organize lm queries in to batch

This commit is contained in:
Marcin Junczys-Dowmunt 2016-04-18 21:16:07 +02:00
parent 349a0266cf
commit b43747ab5c
3 changed files with 26 additions and 13 deletions

View File

@ -102,3 +102,6 @@ float LM::GetWeight() const {
return weight_;
}
size_t LM::size() const {
return vm_.size();
}

View File

@ -55,6 +55,7 @@ class LM {
WordPairs::const_iterator end() const;
size_t GetIndex() const;
float GetWeight() const;
size_t size() const;
private:
std::unique_ptr<KenlmModel> lm_;

View File

@ -10,6 +10,7 @@
#include <set>
#include <boost/timer/timer.hpp>
#include <thread>
#include <algorithm>
#include <thrust/functional.h>
#include <thrust/device_vector.h>
@ -137,23 +138,31 @@ class Search {
std::vector<float> costs(rows * cols);
states.resize(rows * cols);
{
ThreadPool pool(4);
for(size_t i = 0; i < prevHyps.size(); i++) {
auto call = [i, cols, &prevHyps, &lm, &costs, &states] {
const KenlmState state = prevHyps[i]->GetLMStates()[lm.GetIndex()];
KenlmState stateUnk;
float costUnk = lm.Score(state, 0, stateUnk);
std::fill(costs.begin() + i * cols, costs.begin() + i * cols + cols, costUnk);
std::fill(states.begin() + i * cols, states.begin() + i * cols + cols, stateUnk);
for(auto& wp : lm) {
costs[i * cols + wp.second] = lm.Score(state, wp.first, states[i * cols + wp.second]);
}
for(size_t i = 0; i < prevHyps.size(); i++) {
const KenlmState state = prevHyps[i]->GetLMStates()[lm.GetIndex()];
KenlmState stateUnk;
float costUnk = lm.Score(state, 0, stateUnk);
std::fill(costs.begin() + i * cols, costs.begin() + i * cols + cols, costUnk);
std::fill(states.begin() + i * cols, states.begin() + i * cols + cols, stateUnk);
}
{
ThreadPool pool(8);
size_t batchSize = 1000;
for(size_t batchStart = 0; batchStart < lm.size(); batchStart += batchSize) {
auto call = [batchStart, batchSize, cols, &prevHyps, &lm, &costs, &states] {
size_t batchEnd = min(batchStart + batchSize, lm.size());
for(auto it = lm.begin() + batchStart; it != lm.begin() + batchEnd; ++it)
for(size_t i = 0; i < prevHyps.size(); i++) {
const KenlmState state = prevHyps[i]->GetLMStates()[lm.GetIndex()];
costs[i * cols + it->second] = lm.Score(state, it->first, states[i * cols + it->second]);
}
};
pool.enqueue(call);
}
}
cudaSetDevice(device_);
cudaSetDevice(device_); // ???
thrust::copy(costs.begin(), costs.end(), LmProbs.begin());
}