mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-11 09:54:22 +03:00
organize lm queries in to batch
This commit is contained in:
parent
349a0266cf
commit
b43747ab5c
@ -102,3 +102,6 @@ float LM::GetWeight() const {
|
||||
return weight_;
|
||||
}
|
||||
|
||||
size_t LM::size() const {
|
||||
return vm_.size();
|
||||
}
|
||||
|
@ -55,6 +55,7 @@ class LM {
|
||||
WordPairs::const_iterator end() const;
|
||||
size_t GetIndex() const;
|
||||
float GetWeight() const;
|
||||
size_t size() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<KenlmModel> lm_;
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <set>
|
||||
#include <boost/timer/timer.hpp>
|
||||
#include <thread>
|
||||
#include <algorithm>
|
||||
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/device_vector.h>
|
||||
@ -137,23 +138,31 @@ class Search {
|
||||
std::vector<float> costs(rows * cols);
|
||||
states.resize(rows * cols);
|
||||
|
||||
{
|
||||
ThreadPool pool(4);
|
||||
for(size_t i = 0; i < prevHyps.size(); i++) {
|
||||
auto call = [i, cols, &prevHyps, &lm, &costs, &states] {
|
||||
const KenlmState state = prevHyps[i]->GetLMStates()[lm.GetIndex()];
|
||||
KenlmState stateUnk;
|
||||
float costUnk = lm.Score(state, 0, stateUnk);
|
||||
std::fill(costs.begin() + i * cols, costs.begin() + i * cols + cols, costUnk);
|
||||
std::fill(states.begin() + i * cols, states.begin() + i * cols + cols, stateUnk);
|
||||
for(auto& wp : lm) {
|
||||
costs[i * cols + wp.second] = lm.Score(state, wp.first, states[i * cols + wp.second]);
|
||||
}
|
||||
|
||||
{
|
||||
ThreadPool pool(8);
|
||||
size_t batchSize = 1000;
|
||||
for(size_t batchStart = 0; batchStart < lm.size(); batchStart += batchSize) {
|
||||
auto call = [batchStart, batchSize, cols, &prevHyps, &lm, &costs, &states] {
|
||||
size_t batchEnd = min(batchStart + batchSize, lm.size());
|
||||
for(auto it = lm.begin() + batchStart; it != lm.begin() + batchEnd; ++it)
|
||||
for(size_t i = 0; i < prevHyps.size(); i++) {
|
||||
const KenlmState state = prevHyps[i]->GetLMStates()[lm.GetIndex()];
|
||||
costs[i * cols + it->second] = lm.Score(state, it->first, states[i * cols + it->second]);
|
||||
}
|
||||
};
|
||||
pool.enqueue(call);
|
||||
}
|
||||
}
|
||||
cudaSetDevice(device_);
|
||||
|
||||
cudaSetDevice(device_); // ???
|
||||
thrust::copy(costs.begin(), costs.end(), LmProbs.begin());
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user