diff --git a/src/gpu/mblas/nth_element.cu b/src/gpu/mblas/nth_element.cu index 8cce38c0..8e13e828 100644 --- a/src/gpu/mblas/nth_element.cu +++ b/src/gpu/mblas/nth_element.cu @@ -254,7 +254,7 @@ __global__ void gGetValueByKey(float* d_in, float* d_out, int* indeces, int n) NthElement::NthElement(size_t maxBeamSize, size_t maxBatchSize, cudaStream_t& stream) : stream_(stream), - NUM_BLOCKS(std::max(500, int(maxBeamSize * 85000 / (2 * BLOCK_SIZE)) + int(maxBeamSize * 85000 % (2 * BLOCK_SIZE) != 0))) + NUM_BLOCKS(std::min(500, int(maxBeamSize * 85000 / (2 * BLOCK_SIZE)) + int(maxBeamSize * 85000 % (2 * BLOCK_SIZE) != 0))) { HANDLE_ERROR( cudaMalloc((void**)&d_ind, maxBatchSize * NUM_BLOCKS * sizeof(int)) );