mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
move gArgMax to tensor_operators. Write ArgMax() with tensors
This commit is contained in:
parent
a4111bf1fe
commit
5173e9e550
@ -1,5 +1,7 @@
|
|||||||
#include "tensor_operators.h"
|
#include "tensor_operators.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
namespace marian {
|
namespace marian {
|
||||||
|
|
||||||
__global__ void gSubtractMean(float* out, float* weights,
|
__global__ void gSubtractMean(float* out, float* weights,
|
||||||
@ -53,6 +55,7 @@ void SubtractMean(Tensor* Out, Tensor &Weights) {
|
|||||||
cudaStreamSynchronize(0);
|
cudaStreamSynchronize(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////
|
||||||
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
|
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
|
||||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||||
int j = bid + blockIdx.x;
|
int j = bid + blockIdx.x;
|
||||||
@ -97,6 +100,35 @@ void Softmax(Tensor* Out) {
|
|||||||
gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
|
gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
|
||||||
cudaStreamSynchronize(0);
|
cudaStreamSynchronize(0);
|
||||||
}
|
}
|
||||||
|
///////////////////////////////////////////////////////
|
||||||
|
__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols) {
|
||||||
|
size_t row = blockIdx.x;
|
||||||
|
size_t startInd = row * cols;
|
||||||
|
float maxScore = -99999;
|
||||||
|
size_t maxInd;
|
||||||
|
for (size_t col = 0; col < cols; ++col) {
|
||||||
|
size_t ind = startInd + col;
|
||||||
|
float score = data[ind];
|
||||||
|
if (score > maxScore) {
|
||||||
|
maxScore = score;
|
||||||
|
maxInd = col;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out[row] = maxInd;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Argmax(Tensor* Out, const Tensor* In) {
|
||||||
|
size_t m = In->shape()[0];
|
||||||
|
size_t k = In->shape()[1];
|
||||||
|
|
||||||
|
int blocks = m; //std::min(MAX_BLOCKS, (int) m);
|
||||||
|
int threads = k; //std::min(MAX_THREADS, (int) k);
|
||||||
|
//int shared = sizeof(float) * threads * 2;
|
||||||
|
gArgMax<<<blocks, threads>>>(Out->data(), In->data(), m, k);
|
||||||
|
cudaStreamSynchronize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////
|
||||||
|
|
||||||
Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
|
Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
|
||||||
bool transA, bool transB, Float beta) {
|
bool transA, bool transB, Float beta) {
|
||||||
@ -137,4 +169,4 @@ Tensor Prod(Tensor C, const Tensor A, const Tensor B,
|
|||||||
return temp;
|
return temp;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -151,6 +151,10 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
|
|||||||
|
|
||||||
void Softmax(Tensor* Out);
|
void Softmax(Tensor* Out);
|
||||||
|
|
||||||
|
__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols);
|
||||||
|
|
||||||
|
void Argmax(Tensor* Out, const Tensor* In);
|
||||||
|
|
||||||
Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
|
Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
|
||||||
bool transA, bool transB, Float beta);
|
bool transA, bool transB, Float beta);
|
||||||
|
|
||||||
|
34
src/test.cu
34
src/test.cu
@ -8,22 +8,6 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
///////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////
|
||||||
__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols) {
|
|
||||||
size_t row = blockIdx.x;
|
|
||||||
size_t startInd = row * cols;
|
|
||||||
float maxScore = -99999;
|
|
||||||
size_t maxInd = -1;
|
|
||||||
for (size_t col = 0; col < cols; ++col) {
|
|
||||||
size_t ind = startInd + col;
|
|
||||||
float score = data[ind];
|
|
||||||
if (score > maxScore) {
|
|
||||||
maxScore = score;
|
|
||||||
maxInd = col;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out[row] = maxInd;
|
|
||||||
}
|
|
||||||
|
|
||||||
string output(const std::vector<float> &vec)
|
string output(const std::vector<float> &vec)
|
||||||
{
|
{
|
||||||
stringstream strm;
|
stringstream strm;
|
||||||
@ -38,27 +22,27 @@ void temp()
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace marian;
|
using namespace marian;
|
||||||
|
|
||||||
std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
|
std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
|
||||||
cerr << "hVec =" << output(hVec) << endl;
|
cerr << "hVec =" << output(hVec) << endl;
|
||||||
|
|
||||||
thrust::device_vector<float> dVec(8);
|
thrust::device_vector<float> dVec(8);
|
||||||
thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
|
thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
|
||||||
float *data = thrust::raw_pointer_cast(dVec.data());
|
float *data = thrust::raw_pointer_cast(dVec.data());
|
||||||
|
|
||||||
thrust::device_vector<float> dLabel(4);
|
thrust::device_vector<float> dLabel(4);
|
||||||
float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
|
float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
|
||||||
|
|
||||||
gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
|
gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
|
||||||
|
|
||||||
std::vector<float> hVec2(8);
|
std::vector<float> hVec2(8);
|
||||||
thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
|
thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
|
||||||
cerr << "hVec2=" << output(hVec2) << endl;
|
cerr << "hVec2=" << output(hVec2) << endl;
|
||||||
|
|
||||||
std::vector<float> hLabel(4);
|
std::vector<float> hLabel(4);
|
||||||
thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
|
thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
|
||||||
cerr << "hLabel=" << output(hLabel) << endl;
|
cerr << "hLabel=" << output(hLabel) << endl;
|
||||||
|
|
||||||
exit(0);
|
exit(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////
|
||||||
|
Loading…
Reference in New Issue
Block a user