mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
try to use parallelism
This commit is contained in:
parent
269b94cea6
commit
4ea1156e99
19
src/test.cu
19
src/test.cu
@ -25,6 +25,22 @@ __global__ void gArgMax(float* arr, size_t rows, size_t cols) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Per-row argmax: one block handles one row of a row-major matrix.
 *
 * For the row selected by blockIdx.x, scans the `cols` consecutive values
 * starting at arr[row * cols] and overwrites the first element of that row
 * (arr[row * cols]) with the column index of the largest value, stored as a
 * float. Ties keep the lowest column index because the comparison is strict.
 *
 * @param arr   Buffer holding at least rows * cols floats; modified in place.
 * @param rows  Number of rows; blocks whose index is >= rows do nothing.
 * @param cols  Number of columns per row; nothing is written when it is 0.
 */
__global__ void gArgMax2(float* arr, size_t rows, size_t cols) {
  const size_t row = blockIdx.x;

  // The grid can be launched with more blocks than there are rows; without
  // this guard the surplus blocks would read and write past the buffer.
  // An empty row has no element to receive the result either.
  if (row >= rows || cols == 0) {
    return;
  }

  // The scan is sequential, so a single thread per block does the work.
  // Letting every thread run it would have them all overwrite arr[startInd]
  // while other threads may still be reading it as the column-0 score.
  // NOTE(review): assumes a 1-D block (only threadIdx.x is checked) - confirm
  // against the launch configuration.
  if (threadIdx.x != 0) {
    return;
  }

  const size_t startInd = row * cols;

  // Seed with the first element instead of a magic sentinel, so rows whose
  // values are all very negative still yield a valid index.
  float maxScore = arr[startInd];
  size_t maxInd = 0;

  for (size_t col = 1; col < cols; ++col) {
    const float score = arr[startInd + col];
    if (score > maxScore) {
      maxScore = score;
      maxInd = col;
    }
  }

  // The result shares the float buffer with the input, hence the conversion.
  arr[startInd] = static_cast<float>(maxInd);
}
|
||||
|
||||
string output(const std::vector<float> &vec)
|
||||
{
|
||||
stringstream strm;
|
||||
@ -46,7 +62,8 @@ void temp()
|
||||
thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
|
||||
float *data = thrust::raw_pointer_cast(dVec.data());
|
||||
|
||||
gArgMax<<<10, 20, sizeof(float)>>>(data, 4, 2);
|
||||
//gArgMax<<<10, 20, sizeof(float)>>>(data, 4, 2);
|
||||
gArgMax2<<<10, 20, sizeof(float)>>>(data, 4, 2);
|
||||
|
||||
std::vector<float> hVec2(8);
|
||||
thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
|
||||
|
Loading…
Reference in New Issue
Block a user