From 03809913932ccd148e62261741163597ad72fe70 Mon Sep 17 00:00:00 2001
From: Andre Martins
Date: Wed, 14 Sep 2016 08:36:01 +0100
Subject: [PATCH] Implemented fast softmax.

---
 src/expression_operators.h |  7 ++++++
 src/graph_operators.h      | 24 +++++++++++++++++++
 src/tensor.h               |  7 ++++++
 src/tensor_operators.cu    | 47 ++++++++++++++++++++++++++++++++++++++
 src/tensor_operators.h     |  4 ++++
 src/test.cu                |  6 +++--
 6 files changed, 93 insertions(+), 2 deletions(-)

diff --git a/src/expression_operators.h b/src/expression_operators.h
index 8eabbd04..3d42400f 100644
--- a/src/expression_operators.h
+++ b/src/expression_operators.h
@@ -171,6 +171,13 @@ inline Expr softmax(Expr a, Args ...args) {
   return e / sum(e, args...);
 }
 
+template <typename ...Args>
+inline Expr softmax_fast(Expr a, Args ...args) {
+  Expr e = Expr(new SoftmaxNodeOp(a, args...));
+  return e;
+}
+
+
 // inefficient
 template <typename ...Args>
 inline Expr mean(Expr a, Args ...args) {
diff --git a/src/graph_operators.h b/src/graph_operators.h
index 30456153..5a12f807 100644
--- a/src/graph_operators.h
+++ b/src/graph_operators.h
@@ -101,6 +101,30 @@ struct TanhNodeOp : public UnaryNodeOp {
   }
 };
 
+struct SoftmaxNodeOp : public UnaryNodeOp {
+  template <typename ...Args>
+  SoftmaxNodeOp(ChainPtr a, Args ...args)
+    : UnaryNodeOp(a, keywords::shape=newShape(a),
+                  args...) { }
+
+  Shape newShape(ChainPtr a) {
+    Shape shape = a->shape();
+    return shape;
+  }
+
+  void forward() {
+    // B = softmax(A).
+    val_ = a_->val();
+    Softmax(&val_);
+  }
+
+  void backward() {
+    // TODO
+    Element(_1 += _2 * Exp(_3),
+            a_->grad(), adj_, a_->val());
+  }
+};
+
 struct LogNodeOp : public UnaryNodeOp {
   template <typename ...Args>
   LogNodeOp(Args ...args)
diff --git a/src/tensor.h b/src/tensor.h
index 487a553a..bf6b8ef8 100644
--- a/src/tensor.h
+++ b/src/tensor.h
@@ -240,6 +240,13 @@ class Tensor {
     return pimpl_->Debug();
   }
 
+  void Print() const {
+    for (int i = 0; i < size(); ++i) {
+      std::cerr << (*this)[i] << " ";
+    }
+    std::cerr << std::endl;
+  }
+
 };
 
 }
diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu
index a8f72893..2d1d541d 100644
--- a/src/tensor_operators.cu
+++ b/src/tensor_operators.cu
@@ -2,6 +2,53 @@
 
 namespace marian {
 
+// TODO: implement this.
+__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if(j < rows) {
+      extern __shared__ float _share[];
+      float* _sum = _share + blockDim.x;
+      float* sp = softMaxP + j * cols;
+      _sum[threadIdx.x] = 0.0;
+      for(int tid = 0; tid < cols; tid += blockDim.x) {
+        int id = tid + threadIdx.x;
+        if(id < cols) {
+          sp[id] = __expf(sp[id]);
+          _sum[threadIdx.x] += sp[id];
+        }
+      }
+      __syncthreads();
+      int len = blockDim.x;
+      while(len != 1) {
+        __syncthreads();
+        int skip = (len + 1) >> 1;
+        if(threadIdx.x < (len >> 1))
+          _sum[threadIdx.x] += _sum[threadIdx.x + skip];
+        len = (len + 1) >> 1;
+      }
+      __syncthreads();
+      for(int tid = 0; tid < cols; tid += blockDim.x){
+        int id = tid + threadIdx.x;
+        if(id < cols)
+          sp[id] /= _sum[0];
+      }
+    }
+  }
+}
+
+// TODO: implement this.
+void Softmax(Tensor* Out) {
+  size_t m = Out->shape()[0];
+  size_t k = Out->shape()[1];
+
+  int blocks = std::min(MAX_BLOCKS, (int) m);
+  int threads = std::min(MAX_THREADS, (int) k);
+  int shared = sizeof(float) * threads * 2;
+  gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
+  cudaStreamSynchronize(0);
+}
+
 Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta) {
   Float alpha = 1.0;
diff --git a/src/tensor_operators.h b/src/tensor_operators.h
index 7ec4ca68..a0c30104 100644
--- a/src/tensor_operators.h
+++ b/src/tensor_operators.h
@@ -142,6 +142,10 @@ void Element(Functor functor,
   cudaStreamSynchronize(0);
 }
 
+__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
+
+void Softmax(Tensor* Out);
+
 Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta);
 
diff --git a/src/test.cu b/src/test.cu
index 4a2445fd..1948b74f 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -15,7 +15,7 @@ int main(int argc, char** argv) {
   Expr b = param(shape={1, 10}, name="b0");
 
   auto scores = dot(x, w) + b;
-  auto lr = softmax(scores, axis=1, name="pred");
+  auto lr = softmax_fast(scores, axis=1, name="pred");
   auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
 
   cerr << "lr=" << lr.Debug() << endl;
@@ -40,12 +40,14 @@ int main(int argc, char** argv) {
     std::cerr << val << " ";
   }
   std::cerr << std::endl;
+  lr.val().Print();
 
   std::cerr << "Log-likelihood: ";
   for (auto val : graph.val().shape()) {
     std::cerr << val << " ";
   }
   std::cerr << std::endl;
-
+  graph.val().Print();
+
   graph.backward();
   //std::cerr << graph["pred"].val()[0] << std::endl;
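
Two review notes follow; neither is part of the commit above.

Note 1, numerical stability: gSoftMax applies __expf to the raw logits, so a row containing large activations can overflow to infinity, and the subsequent division then turns the whole row into NaN. Since softmax is invariant under subtracting a per-row constant, the standard remedy is one extra reduction pass that finds the row maximum and shifts the logits by it before exponentiating. Below is a minimal sketch of such a variant, reusing the kernel's own strided-loop and tree-reduction pattern. gStableSoftMax is a hypothetical name, FLT_MAX assumes <cfloat> is included, and the two shared buffers fit exactly in the 2 * threads floats that the existing Softmax() launcher already allocates.

// Hypothetical sketch, not in the commit: numerically stable gSoftMax.
__global__ void gStableSoftMax(float* softMaxP, size_t rows, size_t cols) {
  for(int bid = 0; bid < rows; bid += gridDim.x) {
    int j = bid + blockIdx.x;
    if(j < rows) {
      extern __shared__ float _share[];
      float* _max = _share;               // blockDim.x slots for the row max
      float* _sum = _share + blockDim.x;  // blockDim.x slots for the row sum
      float* sp = softMaxP + j * cols;

      // Pass 1: each thread takes the max over its strided slice of the row.
      _max[threadIdx.x] = -FLT_MAX;
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
        if(id < cols)
          _max[threadIdx.x] = fmaxf(_max[threadIdx.x], sp[id]);
      }
      // Tree reduction, same shape as the sum reduction in gSoftMax.
      int len = blockDim.x;
      while(len != 1) {
        __syncthreads();
        int skip = (len + 1) >> 1;
        if(threadIdx.x < (len >> 1))
          _max[threadIdx.x] = fmaxf(_max[threadIdx.x], _max[threadIdx.x + skip]);
        len = (len + 1) >> 1;
      }
      __syncthreads();

      // Pass 2: exponentiate the shifted logits and sum them. Shifting by
      // the row max leaves the softmax unchanged but keeps __expf in range.
      _sum[threadIdx.x] = 0.0f;
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
        if(id < cols) {
          sp[id] = __expf(sp[id] - _max[0]);
          _sum[threadIdx.x] += sp[id];
        }
      }
      __syncthreads();
      len = blockDim.x;
      while(len != 1) {
        __syncthreads();
        int skip = (len + 1) >> 1;
        if(threadIdx.x < (len >> 1))
          _sum[threadIdx.x] += _sum[threadIdx.x + skip];
        len = (len + 1) >> 1;
      }
      __syncthreads();

      // Pass 3: normalize the row.
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
        if(id < cols)
          sp[id] /= _sum[0];
      }
    }
  }
}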
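
Note 2, the backward pass: backward() in SoftmaxNodeOp is explicitly marked TODO, and Element(_1 += _2 * Exp(_3), ...) is a placeholder rather than the actual Jacobian-vector product. For y = softmax(x) taken row-wise with upstream gradient adj, the correct accumulation is grad_i += y_i * (adj_i - sum_k adj_k * y_k), which needs one more per-row reduction and therefore cannot be expressed with Element alone. A sketch in the same style as gSoftMax follows; gSoftMaxGrad is a hypothetical name, and it assumes grad, adj and val are row-major rows x cols buffers launched with the same blocks/threads/shared sizing as in Softmax().

// Hypothetical sketch, not in the commit: softmax gradient kernel.
__global__ void gSoftMaxGrad(float* grad, const float* adj, const float* val,
                             size_t rows, size_t cols) {
  for(int bid = 0; bid < rows; bid += gridDim.x) {
    int j = bid + blockIdx.x;
    if(j < rows) {
      extern __shared__ float _share[];
      float* _sum = _share;  // only blockDim.x floats of the allocation used
      float* gradRow = grad + j * cols;
      const float* adjRow = adj + j * cols;
      const float* valRow = val + j * cols;

      // Each thread accumulates part of sum_k adj[k] * y[k] for this row.
      _sum[threadIdx.x] = 0.0f;
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
        if(id < cols)
          _sum[threadIdx.x] += adjRow[id] * valRow[id];
      }
      __syncthreads();
      // Tree reduction, as in gSoftMax.
      int len = blockDim.x;
      while(len != 1) {
        __syncthreads();
        int skip = (len + 1) >> 1;
        if(threadIdx.x < (len >> 1))
          _sum[threadIdx.x] += _sum[threadIdx.x + skip];
        len = (len + 1) >> 1;
      }
      __syncthreads();
      // grad_i += y_i * (adj_i - sum_k adj_k * y_k); accumulate rather than
      // overwrite, matching how other backward() implementations add to grad.
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
        if(id < cols)
          gradRow[id] += valRow[id] * (adjRow[id] - _sum[0]);
      }
      // Keep the next row's zeroing of _sum from racing with the
      // _sum[0] reads above when a block processes several rows.
      __syncthreads();
    }
  }
}

A host-side wrapper in the style of Softmax() could then replace the Element(...) placeholder in backward() with a call on a_->grad(), adj_ and a_->val().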