move threadpool into God

This commit is contained in:
Hieu Hoang 2017-02-17 17:32:07 +00:00
parent 2f09087e0b
commit 60370ba504
3 changed files with 47 additions and 43 deletions

View File

@ -27,43 +27,25 @@ int main(int argc, char* argv[]) {
std::size_t lineNum = 0;
size_t maxiSize = god.Get<size_t>("maxi-batch");
//std::cerr << "mode=" << god.Get("mode") << std::endl;
size_t cpuThreads = god.Get<size_t>("cpu-threads");
LOG(info) << "Setting CPU thread count to " << cpuThreads;
LOG(info) << "Reading input";
size_t totalThreads = cpuThreads;
#ifdef CUDA
size_t gpuThreads = god.Get<size_t>("gpu-threads");
auto devices = god.Get<std::vector<size_t>>("devices");
LOG(info) << "Setting GPU thread count to " << gpuThreads;
totalThreads += gpuThreads * devices.size();
#endif
SentencesPtr maxiBatch(new Sentences());
LOG(info) << "Total number of threads: " << totalThreads;
amunmt_UTIL_THROW_IF2(totalThreads == 0, "Total number of threads is 0");
{
ThreadPool pool(totalThreads, totalThreads);
LOG(info) << "Reading input";
SentencesPtr maxiBatch(new Sentences());
while (std::getline(god.GetInputStream(), in)) {
maxiBatch->push_back(SentencePtr(new Sentence(god, lineNum++, in)));
if (maxiBatch->size() >= maxiSize) {
god.Enqueue(*maxiBatch, pool);
maxiBatch.reset(new Sentences());
}
while (std::getline(god.GetInputStream(), in)) {
maxiBatch->push_back(SentencePtr(new Sentence(god, lineNum++, in)));
if (maxiBatch->size() >= maxiSize) {
god.Enqueue(*maxiBatch);
maxiBatch.reset(new Sentences());
}
god.Enqueue(*maxiBatch, pool);
}
// last batch
god.Enqueue(*maxiBatch);
LOG(info) << "Total time: " << timer.format();
god.CleanUp();
//sleep(10);
return 0;
}

View File

@ -30,6 +30,13 @@ God::God()
God::~God()
{
pool_.reset();
for (Loaders::value_type& loader : cpuLoaders_) {
loader.second.reset(nullptr);
}
for (Loaders::value_type& loader : gpuLoaders_) {
loader.second.reset(nullptr);
}
}
God& God::Init(const std::string& options) {
@ -86,6 +93,12 @@ God& God::Init(int argc, char** argv) {
LoadPrePostProcessing();
size_t totalThreads = GetTotalThreads();
LOG(info) << "Total number of threads: " << totalThreads;
amunmt_UTIL_THROW_IF2(totalThreads == 0, "Total number of threads is 0");
pool_.reset(new ThreadPool(totalThreads, totalThreads));
return *this;
}
@ -236,15 +249,6 @@ std::vector<std::string> God::Postprocess(const std::vector<std::string>& input)
return processed;
}
void God::CleanUp() {
for (Loaders::value_type& loader : cpuLoaders_) {
loader.second.reset(nullptr);
}
for (Loaders::value_type& loader : gpuLoaders_) {
loader.second.reset(nullptr);
}
}
DeviceInfo God::GetNextDevice() const
{
DeviceInfo ret;
@ -283,19 +287,34 @@ Search &God::GetSearch() const
return obj;
}
void God::Enqueue(Sentences &maxiBatch, ThreadPool &pool)
void God::Enqueue(Sentences &maxiBatch)
{
size_t miniBatch = Get<size_t>("mini-batch");
size_t miniSize = Get<size_t>("mini-batch");
maxiBatch.SortByLength();
while (maxiBatch.size()) {
SentencesPtr miniBatch = maxiBatch.NextMiniBatch(miniBatch);
pool.enqueue(
SentencesPtr miniBatch = maxiBatch.NextMiniBatch(miniSize);
pool_->enqueue(
[this,miniBatch]{ return TranslationTask(*this, miniBatch); }
);
}
}
size_t God::GetTotalThreads() const
{
size_t cpuThreads = Get<size_t>("cpu-threads");
LOG(info) << "Setting CPU thread count to " << cpuThreads;
size_t totalThreads = cpuThreads;
#ifdef CUDA
size_t gpuThreads = Get<size_t>("gpu-threads");
auto devices = Get<std::vector<size_t>>("devices");
LOG(info) << "Setting GPU thread count to " << gpuThreads;
totalThreads += gpuThreads * devices.size();
#endif
return totalThreads;
}
}

View File

@ -66,14 +66,15 @@ class God {
std::vector<std::string> Preprocess(size_t i, const std::vector<std::string>& input) const;
std::vector<std::string> Postprocess(const std::vector<std::string>& input) const;
void CleanUp();
void LoadWeights(const std::string& path);
DeviceInfo GetNextDevice() const;
Search &GetSearch() const;
void Enqueue(Sentences &maxiBatch, ThreadPool &pool);
void Enqueue(Sentences &maxiBatch);
size_t GetTotalThreads() const;
private:
void LoadScorers();
@ -104,6 +105,8 @@ class God {
mutable size_t threadIncr_;
mutable boost::shared_mutex accessLock_;
std::unique_ptr<ThreadPool> pool_;
};
}