mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
input & output are different
This commit is contained in:
parent
7803f44a97
commit
a4111bf1fe
33
src/test.cu
33
src/test.cu
@ -8,37 +8,20 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
///////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////
|
||||||
__global__ void gArgMax(float* arr, size_t rows, size_t cols) {
|
__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols) {
|
||||||
for (size_t row = 0; row < rows; ++row) {
|
|
||||||
size_t startInd = row * cols;
|
|
||||||
float maxScore = -99999;
|
|
||||||
size_t maxInd = -1;
|
|
||||||
for (size_t col = 0; col < cols; ++col) {
|
|
||||||
size_t ind = startInd + col;
|
|
||||||
float score = arr[ind];
|
|
||||||
if (score > maxScore) {
|
|
||||||
maxScore = score;
|
|
||||||
maxInd = col;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
arr[startInd] = maxInd;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void gArgMax2(float* arr, size_t rows, size_t cols) {
|
|
||||||
size_t row = blockIdx.x;
|
size_t row = blockIdx.x;
|
||||||
size_t startInd = row * cols;
|
size_t startInd = row * cols;
|
||||||
float maxScore = -99999;
|
float maxScore = -99999;
|
||||||
size_t maxInd = -1;
|
size_t maxInd = -1;
|
||||||
for (size_t col = 0; col < cols; ++col) {
|
for (size_t col = 0; col < cols; ++col) {
|
||||||
size_t ind = startInd + col;
|
size_t ind = startInd + col;
|
||||||
float score = arr[ind];
|
float score = data[ind];
|
||||||
if (score > maxScore) {
|
if (score > maxScore) {
|
||||||
maxScore = score;
|
maxScore = score;
|
||||||
maxInd = col;
|
maxInd = col;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
arr[startInd] = maxInd;
|
out[row] = maxInd;
|
||||||
}
|
}
|
||||||
|
|
||||||
string output(const std::vector<float> &vec)
|
string output(const std::vector<float> &vec)
|
||||||
@ -62,13 +45,19 @@ void temp()
|
|||||||
thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
|
thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
|
||||||
float *data = thrust::raw_pointer_cast(dVec.data());
|
float *data = thrust::raw_pointer_cast(dVec.data());
|
||||||
|
|
||||||
//gArgMax<<<10, 20, sizeof(float)>>>(data, 4, 2);
|
thrust::device_vector<float> dLabel(4);
|
||||||
gArgMax2<<<4, 1, sizeof(float)>>>(data, 4, 2);
|
float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
|
||||||
|
|
||||||
|
gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
|
||||||
|
|
||||||
std::vector<float> hVec2(8);
|
std::vector<float> hVec2(8);
|
||||||
thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
|
thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
|
||||||
cerr << "hVec2=" << output(hVec2) << endl;
|
cerr << "hVec2=" << output(hVec2) << endl;
|
||||||
|
|
||||||
|
std::vector<float> hLabel(4);
|
||||||
|
thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
|
||||||
|
cerr << "hLabel=" << output(hLabel) << endl;
|
||||||
|
|
||||||
exit(0);
|
exit(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user