mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
Fix again nth_element
This commit is contained in:
parent
eda37f3548
commit
359ec7bd98
@ -254,7 +254,7 @@ __global__ void gGetValueByKey(float* d_in, float* d_out, int* indeces, int n)
|
||||
|
||||
NthElement::NthElement(size_t maxBeamSize, size_t maxBatchSize, cudaStream_t& stream)
|
||||
: stream_(stream),
|
||||
NUM_BLOCKS(std::max(500, int(maxBeamSize * 85000 / (2 * BLOCK_SIZE)) + int(maxBeamSize * 85000 % (2 * BLOCK_SIZE) != 0)))
|
||||
NUM_BLOCKS(std::min(500, int(maxBeamSize * 85000 / (2 * BLOCK_SIZE)) + int(maxBeamSize * 85000 % (2 * BLOCK_SIZE) != 0)))
|
||||
{
|
||||
HANDLE_ERROR( cudaMalloc((void**)&d_ind, maxBatchSize * NUM_BLOCKS * sizeof(int)) );
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user