mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-11 09:54:22 +03:00
parent
20363f1a00
commit
011b7b4289
@ -17,15 +17,8 @@
|
||||
#include "scorer.h"
|
||||
#include "loader_factory.h"
|
||||
|
||||
God::God()
|
||||
:numGPUThreads_(0)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
God::~God() {}
|
||||
|
||||
|
||||
God& God::Init(const std::string& options) {
|
||||
std::vector<std::string> args = boost::program_options::split_unix(options);
|
||||
int argc = args.size() + 1;
|
||||
@ -246,15 +239,8 @@ Search &God::GetSearch(size_t taskCounter)
|
||||
Search *obj;
|
||||
obj = search_.get();
|
||||
if (obj == NULL) {
|
||||
boost::unique_lock<boost::shared_mutex> lock(m_accessLock);
|
||||
|
||||
size_t maxGPUThreads = God::Get<size_t>("gpu-threads");
|
||||
DeviceType deviceType = (numGPUThreads_ < maxGPUThreads) ? GPUDevice : CPUDevice;
|
||||
++numGPUThreads_;
|
||||
|
||||
obj = new Search(*this, deviceType, taskCounter);
|
||||
obj = new Search(*this, taskCounter);
|
||||
search_.reset(obj);
|
||||
|
||||
}
|
||||
assert(obj);
|
||||
return *obj;
|
||||
|
@ -2,7 +2,6 @@
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
#include <boost/thread/tss.hpp>
|
||||
#include <boost/thread/shared_mutex.hpp>
|
||||
|
||||
#include "common/processor/processor.h"
|
||||
#include "common/config.h"
|
||||
@ -27,7 +26,6 @@ class Search;
|
||||
|
||||
class God {
|
||||
public:
|
||||
God();
|
||||
virtual ~God();
|
||||
|
||||
God& Init(const std::string&);
|
||||
@ -95,8 +93,5 @@ class God {
|
||||
std::unique_ptr<InputFileStream> inputStream_;
|
||||
OutputCollector outputCollector_;
|
||||
|
||||
mutable boost::shared_mutex m_accessLock;
|
||||
mutable boost::thread_specific_ptr<Search> search_;
|
||||
size_t numGPUThreads_;
|
||||
|
||||
};
|
||||
|
@ -9,7 +9,7 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
Search::Search(God &god, DeviceType deviceType, size_t threadId)
|
||||
Search::Search(God &god, size_t threadId)
|
||||
: scorers_(god.GetScorers(threadId)),
|
||||
bestHyps_(god.GetBestHyps(threadId)) {
|
||||
}
|
||||
|
@ -9,7 +9,7 @@
|
||||
|
||||
class Search {
|
||||
public:
|
||||
Search(God &god, DeviceType deviceType, size_t threadId);
|
||||
Search(God &god, size_t threadId);
|
||||
std::shared_ptr<Histories> Decode(God &god, const Sentences& sentences);
|
||||
|
||||
private:
|
||||
|
@ -10,8 +10,3 @@ typedef std::vector<Word> Words;
|
||||
const Word EOS = 0;
|
||||
const Word UNK = 1;
|
||||
|
||||
enum DeviceType
|
||||
{
|
||||
CPUDevice,
|
||||
GPUDevice
|
||||
};
|
||||
|
@ -16,11 +16,16 @@
|
||||
God *god_;
|
||||
|
||||
std::shared_ptr<Histories> TranslationTask(const std::string& in, size_t taskCounter) {
|
||||
Search &search = god_->GetSearch(taskCounter);
|
||||
thread_local std::unique_ptr<Search> search;
|
||||
|
||||
if(!search) {
|
||||
LOG(info) << "Created Search for thread " << std::this_thread::get_id();
|
||||
search.reset(new Search(*god_, taskCounter));
|
||||
}
|
||||
|
||||
std::shared_ptr<Sentences> sentences(new Sentences());
|
||||
sentences->push_back(SentencePtr(new Sentence(*god_, taskCounter, in)));
|
||||
return search.Decode(*god_, *sentences);
|
||||
return search->Decode(*god_, *sentences);
|
||||
}
|
||||
|
||||
void init(const std::string& options) {
|
||||
|
Loading…
Reference in New Issue
Block a user