This commit is contained in:
Hieu Hoang 2017-01-25 15:33:38 +00:00
parent 4f8a0401d4
commit 47f01d3f44
4 changed files with 20 additions and 6 deletions

View File

@ -29,10 +29,11 @@ add_library(libcommon OBJECT
common/logging.cpp
common/output_collector.cpp
common/printer.cpp
common/processor/bpe.cpp
common/scorer.cpp
common/search.cpp
common/sentence.cpp
common/processor/bpe.cpp
common/types.cpp
common/utils.cpp
common/vocab.cpp
common/translation_task.cpp

View File

@ -17,6 +17,8 @@
#include "scorer.h"
#include "loader_factory.h"
using namespace std;
God::God()
:threadIncr_(0)
{
@ -252,10 +254,10 @@ DeviceInfo God::GetNextDevice() const
size_t gpuThreads = Get<size_t>("gpu-threads");
std::vector<size_t> devices = Get<std::vector<size_t>>("devices");
ret.threadInd = threadIncrGPU / gpuThreads;
ret.threadInd = threadIncrGPU / devices.size();
ret.deviceInd = threadIncrGPU % gpuThreads;
UTIL_THROW_IF2(ret.deviceInd >= devices.size(), "Too many GPU devices");
ret.deviceInd = threadIncrGPU % devices.size();
assert(ret.deviceInd < devices.size());
ret.deviceInd = devices[ret.deviceInd];
UTIL_THROW_IF2(ret.threadInd >= gpuThreads, "Too many GPU threads");
@ -263,6 +265,7 @@ DeviceInfo God::GetNextDevice() const
++threadIncr_;
cerr << "GetNextDevice=" << ret << endl;
return ret;
}

7
src/common/types.cpp Normal file
View File

@ -0,0 +1,7 @@
#include "types.h"
std::ostream& operator<<(std::ostream& out, const DeviceInfo& obj)
{
out << obj.deviceType << " t=" << obj.threadInd << " d=" << obj.deviceInd;
return out;
}

View File

@ -3,6 +3,7 @@
#include <cstdlib>
#include <cstdint>
#include <vector>
#include <iostream>
typedef size_t Word;
typedef std::vector<Word> Words;
@ -12,12 +13,14 @@ const Word UNK = 1;
enum DeviceType
{
CPUDevice,
GPUDevice
CPUDevice = 7,
GPUDevice = 11
};
struct DeviceInfo
{
friend std::ostream& operator<<(std::ostream& out, const DeviceInfo& obj);
DeviceType deviceType;
size_t threadInd;
size_t deviceInd;