GetNextDevice()

This commit is contained in:
Hieu Hoang 2017-01-23 12:56:19 +00:00
parent e10d1c3d86
commit 1a9392b075
3 changed files with 49 additions and 0 deletions

View File

@ -16,6 +16,12 @@
#include "scorer.h"
#include "loader_factory.h"
God::God()
:threadIncr_(0)
{
}
God::~God() {}
God& God::Init(const std::string& options) {
@ -232,3 +238,29 @@ void God::CleanUp() {
loader.reset(nullptr);
}
}
DeviceInfo God::GetNextDevice() const
{
DeviceInfo ret;
size_t cpuThreads = God::Get<size_t>("cpu-threads");
ret.deviceType = (threadIncr_ < cpuThreads) ? CPUDevice : GPUDevice;
if (ret.deviceType == CPUDevice) {
ret.threadInd = threadIncr_;
}
else {
size_t threadIncrGPU = threadIncr_ - cpuThreads;
size_t gpuThreads = Get<size_t>("gpu-threads");
ret.threadInd = threadIncrGPU / gpuThreads;
ret.deviceInd = threadIncrGPU % gpuThreads;
std::vector<size_t> devices = Get<std::vector<size_t>>("devices");
UTIL_THROW_IF2(ret.threadInd >= gpuThreads, "Too many GPU threads");
UTIL_THROW_IF2(ret.deviceInd >= devices.size(), "Too many GPU devices");
}
return ret;
}

View File

@ -24,6 +24,7 @@ class InputFileStream;
class God {
public:
God();
virtual ~God();
God& Init(const std::string&);
@ -89,4 +90,8 @@ class God {
mutable std::unique_ptr<InputFileStream> inputStream_;
mutable OutputCollector outputCollector_;
mutable size_t threadIncr_;
DeviceInfo GetNextDevice() const;
};

View File

@ -10,3 +10,15 @@ typedef std::vector<Word> Words;
const Word EOS = 0;
const Word UNK = 1;
enum DeviceType
{
CPUDevice,
GPUDevice
};
struct DeviceInfo
{
DeviceType deviceType;
size_t threadInd;
size_t deviceInd;
};