diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu
index e9e09ee6..aa92f0dd 100644
--- a/src/tensor_operators.cu
+++ b/src/tensor_operators.cu
@@ -1,5 +1,7 @@
 #include "tensor_operators.h"
 
+using namespace std;
+
 namespace marian {
 
 __global__ void gSubtractMean(float* out, float* weights,
@@ -53,6 +55,7 @@ void SubtractMean(Tensor* Out, Tensor &Weights) {
   cudaStreamSynchronize(0);
 }
 
+///////////////////////////////////////////////////////
 __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
   for(int bid = 0; bid < rows; bid += gridDim.x) {
     int j = bid + blockIdx.x;
@@ -97,6 +100,35 @@ void Softmax(Tensor* Out) {
   gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
   cudaStreamSynchronize(0);
 }
+///////////////////////////////////////////////////////
+__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols) {
+  size_t row = blockIdx.x;
+  size_t startInd = row * cols;
+  float maxScore = -99999;
+  size_t maxInd;
+  for (size_t col = 0; col < cols; ++col) {
+    size_t ind = startInd + col;
+    float score = data[ind];
+    if (score > maxScore) {
+      maxScore = score;
+      maxInd = col;
+    }
+  }
+  out[row] = maxInd;
+}
+
+void Argmax(Tensor* Out, const Tensor* In) {
+  size_t m = In->shape()[0];
+  size_t k = In->shape()[1];
+
+  int blocks = m; //std::min(MAX_BLOCKS, (int) m);
+  int threads = k; //std::min(MAX_THREADS, (int) k);
+  //int shared = sizeof(float) * threads * 2;
+  gArgMax<<<blocks, threads>>>(Out->data(), In->data(), m, k);
+  cudaStreamSynchronize(0);
+}
+
+///////////////////////////////////////////////////////
 
 Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta) {
@@ -137,4 +169,4 @@ Tensor Prod(Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta) {
   return temp;
 }
-}
\ No newline at end of file
+}
diff --git a/src/tensor_operators.h b/src/tensor_operators.h
index 03d754e3..60e989d2 100644
--- a/src/tensor_operators.h
+++ b/src/tensor_operators.h
@@ -151,6 +151,10 @@ __global__ void gSoftMax(float* softMaxP,
                         size_t rows, size_t cols);
 
 void Softmax(Tensor* Out);
 
+__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols);
+
+void Argmax(Tensor* Out, const Tensor* In);
+
 Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta);
diff --git a/src/test.cu b/src/test.cu
index 25ec7b5d..5b128976 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -8,22 +8,6 @@ using namespace std;
 
 ///////////////////////////////////////////////////////
-__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols) {
-  size_t row = blockIdx.x;
-  size_t startInd = row * cols;
-  float maxScore = -99999;
-  size_t maxInd = -1;
-  for (size_t col = 0; col < cols; ++col) {
-    size_t ind = startInd + col;
-    float score = data[ind];
-    if (score > maxScore) {
-      maxScore = score;
-      maxInd = col;
-    }
-  }
-  out[row] = maxInd;
-}
-
 
 string output(const std::vector<float> &vec)
 {
 	stringstream strm;
@@ -38,27 +22,27 @@ void temp()
 	using namespace std;
 	using namespace marian;
 
-	std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
+	std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
 	cerr << "hVec =" << output(hVec) << endl;
 
-	thrust::device_vector<float> dVec(8);
-	thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
-	float *data = thrust::raw_pointer_cast(dVec.data());
+	thrust::device_vector<float> dVec(8);
+	thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
+	float *data = thrust::raw_pointer_cast(dVec.data());
 
 	thrust::device_vector<float> dLabel(4);
 	float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
 
-	gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
+	gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
 
-	std::vector<float> hVec2(8);
-	thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
-	cerr << "hVec2=" << output(hVec2) << endl;
+	std::vector<float> hVec2(8);
+	thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
+	cerr << "hVec2=" << output(hVec2) << endl;
 
 	std::vector<float> hLabel(4);
 	thrust::copy(dLabel.begin(), dLabel.end(),
 				 hLabel.begin());
 	cerr << "hLabel=" << output(hLabel) << endl;
 
-	exit(0);
+	exit(0);
 }
 ///////////////////////////////////////////////////////