mirror of
https://github.com/marian-nmt/marian.git
synced 2025-01-07 17:10:15 +03:00
GetNextDevice()
This commit is contained in:
parent
e10d1c3d86
commit
1a9392b075
@ -16,6 +16,12 @@
|
||||
#include "scorer.h"
|
||||
#include "loader_factory.h"
|
||||
|
||||
God::God()
|
||||
:threadIncr_(0)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
God::~God() {}
|
||||
|
||||
God& God::Init(const std::string& options) {
|
||||
@ -232,3 +238,29 @@ void God::CleanUp() {
|
||||
loader.reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
DeviceInfo God::GetNextDevice() const
|
||||
{
|
||||
DeviceInfo ret;
|
||||
|
||||
size_t cpuThreads = God::Get<size_t>("cpu-threads");
|
||||
ret.deviceType = (threadIncr_ < cpuThreads) ? CPUDevice : GPUDevice;
|
||||
|
||||
if (ret.deviceType == CPUDevice) {
|
||||
ret.threadInd = threadIncr_;
|
||||
}
|
||||
else {
|
||||
size_t threadIncrGPU = threadIncr_ - cpuThreads;
|
||||
size_t gpuThreads = Get<size_t>("gpu-threads");
|
||||
|
||||
ret.threadInd = threadIncrGPU / gpuThreads;
|
||||
ret.deviceInd = threadIncrGPU % gpuThreads;
|
||||
|
||||
std::vector<size_t> devices = Get<std::vector<size_t>>("devices");
|
||||
UTIL_THROW_IF2(ret.threadInd >= gpuThreads, "Too many GPU threads");
|
||||
UTIL_THROW_IF2(ret.deviceInd >= devices.size(), "Too many GPU devices");
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -24,6 +24,7 @@ class InputFileStream;
|
||||
|
||||
class God {
|
||||
public:
|
||||
God();
|
||||
virtual ~God();
|
||||
|
||||
God& Init(const std::string&);
|
||||
@ -89,4 +90,8 @@ class God {
|
||||
mutable std::unique_ptr<InputFileStream> inputStream_;
|
||||
mutable OutputCollector outputCollector_;
|
||||
|
||||
mutable size_t threadIncr_;
|
||||
|
||||
DeviceInfo GetNextDevice() const;
|
||||
|
||||
};
|
||||
|
@ -10,3 +10,15 @@ typedef std::vector<Word> Words;
|
||||
const Word EOS = 0;
|
||||
const Word UNK = 1;
|
||||
|
||||
enum DeviceType
|
||||
{
|
||||
CPUDevice,
|
||||
GPUDevice
|
||||
};
|
||||
|
||||
struct DeviceInfo
|
||||
{
|
||||
DeviceType deviceType;
|
||||
size_t threadInd;
|
||||
size_t deviceInd;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user