diff --git a/src/test.cu b/src/test.cu index da19db0f..25ec7b5d 100644 --- a/src/test.cu +++ b/src/test.cu @@ -8,37 +8,20 @@ using namespace std; /////////////////////////////////////////////////////// -__global__ void gArgMax(float* arr, size_t rows, size_t cols) { - for (size_t row = 0; row < rows; ++row) { - size_t startInd = row * cols; - float maxScore = -99999; - size_t maxInd = -1; - for (size_t col = 0; col < cols; ++col) { - size_t ind = startInd + col; - float score = arr[ind]; - if (score > maxScore) { - maxScore = score; - maxInd = col; - } - } - arr[startInd] = maxInd; - } -} - -__global__ void gArgMax2(float* arr, size_t rows, size_t cols) { +__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols) { size_t row = blockIdx.x; size_t startInd = row * cols; float maxScore = -99999; size_t maxInd = -1; for (size_t col = 0; col < cols; ++col) { size_t ind = startInd + col; - float score = arr[ind]; + float score = data[ind]; if (score > maxScore) { maxScore = score; maxInd = col; } } - arr[startInd] = maxInd; + out[row] = maxInd; } string output(const std::vector &vec) @@ -62,13 +45,19 @@ void temp() thrust::copy(hVec.begin(), hVec.end(), dVec.begin()); float *data = thrust::raw_pointer_cast(dVec.data()); - //gArgMax<<<10, 20, sizeof(float)>>>(data, 4, 2); - gArgMax2<<<4, 1, sizeof(float)>>>(data, 4, 2); + thrust::device_vector dLabel(4); + float *labelPtr = thrust::raw_pointer_cast(dLabel.data()); + + gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2); std::vector hVec2(8); thrust::copy(dVec.begin(), dVec.end(), hVec2.begin()); cerr << "hVec2=" << output(hVec2) << endl; + std::vector hLabel(4); + thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin()); + cerr << "hLabel=" << output(hLabel) << endl; + exit(0); }