mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
move threadpool into God
This commit is contained in:
parent
2f09087e0b
commit
60370ba504
@ -27,43 +27,25 @@ int main(int argc, char* argv[]) {
|
||||
std::size_t lineNum = 0;
|
||||
|
||||
size_t maxiSize = god.Get<size_t>("maxi-batch");
|
||||
//std::cerr << "mode=" << god.Get("mode") << std::endl;
|
||||
|
||||
size_t cpuThreads = god.Get<size_t>("cpu-threads");
|
||||
LOG(info) << "Setting CPU thread count to " << cpuThreads;
|
||||
LOG(info) << "Reading input";
|
||||
|
||||
size_t totalThreads = cpuThreads;
|
||||
#ifdef CUDA
|
||||
size_t gpuThreads = god.Get<size_t>("gpu-threads");
|
||||
auto devices = god.Get<std::vector<size_t>>("devices");
|
||||
LOG(info) << "Setting GPU thread count to " << gpuThreads;
|
||||
totalThreads += gpuThreads * devices.size();
|
||||
#endif
|
||||
SentencesPtr maxiBatch(new Sentences());
|
||||
|
||||
LOG(info) << "Total number of threads: " << totalThreads;
|
||||
amunmt_UTIL_THROW_IF2(totalThreads == 0, "Total number of threads is 0");
|
||||
|
||||
{
|
||||
ThreadPool pool(totalThreads, totalThreads);
|
||||
LOG(info) << "Reading input";
|
||||
|
||||
SentencesPtr maxiBatch(new Sentences());
|
||||
|
||||
while (std::getline(god.GetInputStream(), in)) {
|
||||
maxiBatch->push_back(SentencePtr(new Sentence(god, lineNum++, in)));
|
||||
|
||||
if (maxiBatch->size() >= maxiSize) {
|
||||
god.Enqueue(*maxiBatch, pool);
|
||||
maxiBatch.reset(new Sentences());
|
||||
}
|
||||
while (std::getline(god.GetInputStream(), in)) {
|
||||
maxiBatch->push_back(SentencePtr(new Sentence(god, lineNum++, in)));
|
||||
|
||||
if (maxiBatch->size() >= maxiSize) {
|
||||
god.Enqueue(*maxiBatch);
|
||||
maxiBatch.reset(new Sentences());
|
||||
}
|
||||
|
||||
god.Enqueue(*maxiBatch, pool);
|
||||
}
|
||||
|
||||
// last batch
|
||||
god.Enqueue(*maxiBatch);
|
||||
|
||||
LOG(info) << "Total time: " << timer.format();
|
||||
god.CleanUp();
|
||||
//sleep(10);
|
||||
return 0;
|
||||
}
|
||||
|
@ -30,6 +30,13 @@ God::God()
|
||||
|
||||
God::~God()
|
||||
{
|
||||
pool_.reset();
|
||||
for (Loaders::value_type& loader : cpuLoaders_) {
|
||||
loader.second.reset(nullptr);
|
||||
}
|
||||
for (Loaders::value_type& loader : gpuLoaders_) {
|
||||
loader.second.reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
God& God::Init(const std::string& options) {
|
||||
@ -86,6 +93,12 @@ God& God::Init(int argc, char** argv) {
|
||||
|
||||
LoadPrePostProcessing();
|
||||
|
||||
size_t totalThreads = GetTotalThreads();
|
||||
LOG(info) << "Total number of threads: " << totalThreads;
|
||||
amunmt_UTIL_THROW_IF2(totalThreads == 0, "Total number of threads is 0");
|
||||
|
||||
pool_.reset(new ThreadPool(totalThreads, totalThreads));
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
@ -236,15 +249,6 @@ std::vector<std::string> God::Postprocess(const std::vector<std::string>& input)
|
||||
return processed;
|
||||
}
|
||||
|
||||
void God::CleanUp() {
|
||||
for (Loaders::value_type& loader : cpuLoaders_) {
|
||||
loader.second.reset(nullptr);
|
||||
}
|
||||
for (Loaders::value_type& loader : gpuLoaders_) {
|
||||
loader.second.reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
DeviceInfo God::GetNextDevice() const
|
||||
{
|
||||
DeviceInfo ret;
|
||||
@ -283,19 +287,34 @@ Search &God::GetSearch() const
|
||||
return obj;
|
||||
}
|
||||
|
||||
void God::Enqueue(Sentences &maxiBatch, ThreadPool &pool)
|
||||
void God::Enqueue(Sentences &maxiBatch)
|
||||
{
|
||||
size_t miniBatch = Get<size_t>("mini-batch");
|
||||
size_t miniSize = Get<size_t>("mini-batch");
|
||||
|
||||
maxiBatch.SortByLength();
|
||||
while (maxiBatch.size()) {
|
||||
SentencesPtr miniBatch = maxiBatch.NextMiniBatch(miniBatch);
|
||||
pool.enqueue(
|
||||
SentencesPtr miniBatch = maxiBatch.NextMiniBatch(miniSize);
|
||||
pool_->enqueue(
|
||||
[this,miniBatch]{ return TranslationTask(*this, miniBatch); }
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
size_t God::GetTotalThreads() const
|
||||
{
|
||||
size_t cpuThreads = Get<size_t>("cpu-threads");
|
||||
LOG(info) << "Setting CPU thread count to " << cpuThreads;
|
||||
|
||||
size_t totalThreads = cpuThreads;
|
||||
#ifdef CUDA
|
||||
size_t gpuThreads = Get<size_t>("gpu-threads");
|
||||
auto devices = Get<std::vector<size_t>>("devices");
|
||||
LOG(info) << "Setting GPU thread count to " << gpuThreads;
|
||||
totalThreads += gpuThreads * devices.size();
|
||||
#endif
|
||||
return totalThreads;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -66,14 +66,15 @@ class God {
|
||||
std::vector<std::string> Preprocess(size_t i, const std::vector<std::string>& input) const;
|
||||
std::vector<std::string> Postprocess(const std::vector<std::string>& input) const;
|
||||
|
||||
void CleanUp();
|
||||
|
||||
void LoadWeights(const std::string& path);
|
||||
|
||||
DeviceInfo GetNextDevice() const;
|
||||
Search &GetSearch() const;
|
||||
|
||||
void Enqueue(Sentences &maxiBatch, ThreadPool &pool);
|
||||
void Enqueue(Sentences &maxiBatch);
|
||||
|
||||
size_t GetTotalThreads() const;
|
||||
|
||||
private:
|
||||
void LoadScorers();
|
||||
@ -104,6 +105,8 @@ class God {
|
||||
|
||||
mutable size_t threadIncr_;
|
||||
mutable boost::shared_mutex accessLock_;
|
||||
|
||||
std::unique_ptr<ThreadPool> pool_;
|
||||
};
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user