Fix again nth_element

This commit is contained in:
Tomasz Dwojak 2017-03-09 11:54:33 +00:00
parent eda37f3548
commit 359ec7bd98

View File

@ -254,7 +254,7 @@ __global__ void gGetValueByKey(float* d_in, float* d_out, int* indeces, int n)
NthElement::NthElement(size_t maxBeamSize, size_t maxBatchSize, cudaStream_t& stream)
: stream_(stream),
NUM_BLOCKS(std::max(500, int(maxBeamSize * 85000 / (2 * BLOCK_SIZE)) + int(maxBeamSize * 85000 % (2 * BLOCK_SIZE) != 0)))
NUM_BLOCKS(std::min(500, int(maxBeamSize * 85000 / (2 * BLOCK_SIZE)) + int(maxBeamSize * 85000 % (2 * BLOCK_SIZE) != 0)))
{
HANDLE_ERROR( cudaMalloc((void**)&d_ind, maxBatchSize * NUM_BLOCKS * sizeof(int)) );