mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
input & output are different
This commit is contained in:
parent
7803f44a97
commit
a4111bf1fe
33
src/test.cu
33
src/test.cu
@ -8,37 +8,20 @@
|
||||
using namespace std;
|
||||
|
||||
///////////////////////////////////////////////////////
|
||||
__global__ void gArgMax(float* arr, size_t rows, size_t cols) {
|
||||
for (size_t row = 0; row < rows; ++row) {
|
||||
size_t startInd = row * cols;
|
||||
float maxScore = -99999;
|
||||
size_t maxInd = -1;
|
||||
for (size_t col = 0; col < cols; ++col) {
|
||||
size_t ind = startInd + col;
|
||||
float score = arr[ind];
|
||||
if (score > maxScore) {
|
||||
maxScore = score;
|
||||
maxInd = col;
|
||||
}
|
||||
}
|
||||
arr[startInd] = maxInd;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void gArgMax2(float* arr, size_t rows, size_t cols) {
|
||||
__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols) {
|
||||
size_t row = blockIdx.x;
|
||||
size_t startInd = row * cols;
|
||||
float maxScore = -99999;
|
||||
size_t maxInd = -1;
|
||||
for (size_t col = 0; col < cols; ++col) {
|
||||
size_t ind = startInd + col;
|
||||
float score = arr[ind];
|
||||
float score = data[ind];
|
||||
if (score > maxScore) {
|
||||
maxScore = score;
|
||||
maxInd = col;
|
||||
}
|
||||
}
|
||||
arr[startInd] = maxInd;
|
||||
out[row] = maxInd;
|
||||
}
|
||||
|
||||
string output(const std::vector<float> &vec)
|
||||
@ -62,13 +45,19 @@ void temp()
|
||||
thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
|
||||
float *data = thrust::raw_pointer_cast(dVec.data());
|
||||
|
||||
//gArgMax<<<10, 20, sizeof(float)>>>(data, 4, 2);
|
||||
gArgMax2<<<4, 1, sizeof(float)>>>(data, 4, 2);
|
||||
thrust::device_vector<float> dLabel(4);
|
||||
float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
|
||||
|
||||
gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
|
||||
|
||||
std::vector<float> hVec2(8);
|
||||
thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
|
||||
cerr << "hVec2=" << output(hVec2) << endl;
|
||||
|
||||
std::vector<float> hLabel(4);
|
||||
thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
|
||||
cerr << "hLabel=" << output(hLabel) << endl;
|
||||
|
||||
exit(0);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user