diff --git a/src/node_operators.h b/src/node_operators.h
index e7994c0a..c444f24f 100644
--- a/src/node_operators.h
+++ b/src/node_operators.h
@@ -156,6 +156,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
   void forward() {
     // B = softmax(A).
     val_ = a_->val();
+    SubtractMax(&val_); // Safe version of softmax.
     Softmax(&val_);
   }
 
diff --git a/src/optimizers.h b/src/optimizers.h
index a977d7f8..184b063f 100644
--- a/src/optimizers.h
+++ b/src/optimizers.h
@@ -95,4 +95,4 @@ class Adam {
     std::vector<Tensor> vt_;
 };
 
-}
\ No newline at end of file
+}
diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu
index aa92f0dd..34ab874a 100644
--- a/src/tensor_operators.cu
+++ b/src/tensor_operators.cu
@@ -55,6 +55,56 @@ void SubtractMean(Tensor* Out, Tensor &Weights) {
   cudaStreamSynchronize(0);
 }
 
+__global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if (j < rows) {
+      extern __shared__ float _share[];
+      float* _max = _share + blockDim.x;
+      float* sp = out + j * cols;
+      _max[threadIdx.x] = sp[threadIdx.x];
+      for(int tid = 1; tid < cols; tid += blockDim.x) {
+        int id = tid + threadIdx.x;
+        if (id < cols) {
+          if (sp[id] > _max[threadIdx.x]) _max[threadIdx.x] = sp[id];
+        }
+      }
+      __syncthreads();
+      int len = blockDim.x;
+      while(len != 1) {
+        __syncthreads();
+        int skip = (len + 1) >> 1;
+        if (threadIdx.x < (len >> 1)) {
+          if (_max[threadIdx.x + skip] > _max[threadIdx.x]) {
+            _max[threadIdx.x] = _max[threadIdx.x + skip];
+          }
+        }
+        len = (len + 1) >> 1;
+      }
+      __syncthreads();
+      for(int tid = 0; tid < cols; tid += blockDim.x){
+        int id = tid + threadIdx.x;
+        if(id < cols)
+          sp[id] -= _max[0];
+      }
+    }
+  }
+}
+
+void SubtractMax(Tensor* Out) {
+  // Out is an m-by-k matrix, passed as input.
+  // The max element of each row of Out is computed and subtracted from Out.
+  // Out is both input and output.
+  size_t m = Out->shape()[0];
+  size_t k = Out->shape()[1];
+
+  int blocks = std::min(MAX_BLOCKS, (int) m);
+  int threads = std::min(MAX_THREADS, (int) k);
+  int shared = sizeof(float) * threads * 2;
+  gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
+  cudaStreamSynchronize(0);
+}
+
 ///////////////////////////////////////////////////////
 __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
   for(int bid = 0; bid < rows; bid += gridDim.x) {
diff --git a/src/tensor_operators.h b/src/tensor_operators.h
index 60e989d2..039e6f39 100644
--- a/src/tensor_operators.h
+++ b/src/tensor_operators.h
@@ -147,6 +147,10 @@ __global__ void gSubtractMean(float* out, float* weights,
 
 void SubtractMean(Tensor* Out, Tensor &Weights);
 
+__global__ void gSubtractMax(float* out, size_t rows, size_t cols);
+
+void SubtractMax(Tensor* Out);
+
 __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
 
 void Softmax(Tensor* Out);
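
The `SubtractMax(&val_)` call added to `SoftmaxNodeOp::forward()` is what makes this the "safe version of softmax": since softmax(x) = softmax(x - c) for any constant c, shifting each row by its maximum leaves the result unchanged while guaranteeing every exponent passed to `expf` is at most 0, so nothing can overflow to `inf`. Below is a minimal host-side C++ sketch of the same two-step computation, useful as a CPU reference against the kernels; `safeSoftmaxRow` is a hypothetical name and not part of this patch:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Subtract the row maximum, then exponentiate and normalize: the same
// SubtractMax + Softmax sequence the patch adds to SoftmaxNodeOp::forward().
void safeSoftmaxRow(std::vector<float>& row) {
  float max = *std::max_element(row.begin(), row.end());
  float sum = 0.0f;
  for (float& x : row) {
    x = std::exp(x - max);  // every exponent is <= 0, so no overflow
    sum += x;
  }
  for (float& x : row) x /= sum;
}

int main() {
  // Without the shift, exp(1000.0f) overflows to inf and the result is NaN.
  std::vector<float> row = {1000.0f, 1001.0f, 1002.0f};
  safeSoftmaxRow(row);
  for (float x : row) std::printf("%g ", x);  // ~0.090 0.245 0.665
  std::printf("\n");
  return 0;
}
```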
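
Inside `gSubtractMax`, each block walks over rows in a grid-strided loop, each thread keeps a running maximum over a strided slice of the row's columns in shared memory, and a tree reduction then folds the per-thread maxima into `_max[0]`. The rounding in `skip = (len + 1) >> 1` is what keeps the reduction correct when `blockDim.x` is not a power of two. The short host-side simulation below mirrors that while-loop exactly and checks it for lengths 1 through 9 (illustrative only; `treeMax` is a hypothetical name):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Mirrors the reduction loop in gSubtractMax on the host: fold v[0..len)
// down to v[0] by repeatedly comparing element i with element i + skip.
float treeMax(std::vector<float> v) {
  int len = int(v.size());
  while (len != 1) {
    int skip = (len + 1) >> 1;            // round up so odd lengths work
    for (int i = 0; i < (len >> 1); ++i)  // "thread" i pairs with i + skip
      if (v[i + skip] > v[i]) v[i] = v[i + skip];
    len = skip;  // when len is odd, the middle element survives untouched
  }
  return v[0];
}

int main() {
  // Works for any length, including non-powers of two such as 3, 5, 6, 7.
  for (int n = 1; n <= 9; ++n) {
    std::vector<float> v(n);
    for (int i = 0; i < n; ++i) v[i] = float(i);  // maximum sits at the end
    assert(treeMax(v) == float(n - 1));
    std::printf("n=%d  treeMax=%g\n", n, treeMax(v));
  }
  return 0;
}
```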
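
Since `gSubtractMax` takes a raw `float*`, it can also be smoke-tested without the `Tensor` wrapper. A sketch of such a check follows; it assumes both files are compiled with `nvcc -rdc=true` so the kernel links across translation units, and the matrix sizes and launch configuration are illustrative only (the real `SubtractMax` wrapper clamps them with `MAX_BLOCKS`/`MAX_THREADS`):

```cuda
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

// Forward declaration; the definition lives in tensor_operators.cu.
__global__ void gSubtractMax(float* out, size_t rows, size_t cols);

int main() {
  const size_t rows = 3, cols = 5;
  float h[rows * cols];
  for (size_t i = 0; i < rows * cols; ++i)
    h[i] = float(i % 7) - 3.0f;  // arbitrary test values

  float* d;
  cudaMalloc(&d, sizeof(h));
  cudaMemcpy(d, h, sizeof(h), cudaMemcpyHostToDevice);

  // One block per row, one thread per column; the kernel expects
  // 2 * blockDim.x floats of dynamic shared memory (see SubtractMax).
  int threads = int(cols);
  gSubtractMax<<<int(rows), threads, 2 * threads * sizeof(float)>>>(d, rows, cols);
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);

  // After the kernel, the maximum of every row should be exactly 0.
  for (size_t r = 0; r < rows; ++r) {
    float rowMax = h[r * cols];
    for (size_t c = 1; c < cols; ++c)
      rowMax = std::max(rowMax, h[r * cols + c]);
    std::printf("row %zu max after subtraction: %g\n", r, rowMax);
  }
  cudaFree(d);
  return 0;
}
```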