From ed684e25b6f532819e001467d711ba40d67914bd Mon Sep 17 00:00:00 2001
From: Marcin Junczys-Dowmunt
Date: Sun, 18 Sep 2016 16:21:15 +0200
Subject: [PATCH] combined elementwise functions for softmax gradient into
 single kernel call, no temporary objects needed, slightly faster

---
 src/chainable.h         |  1 +
 src/mnist_benchmark.cu  |  1 -
 src/node_operators.h    | 13 +++++-----
 src/tensor_operators.cu | 41 ++++++++++++++++---------------
 src/tensor_operators.h  | 11 +--------
 src/test.cu             | 54 ++++++++++++++++++++---------------------
 6 files changed, 57 insertions(+), 64 deletions(-)

diff --git a/src/chainable.h b/src/chainable.h
index a1683966..fa4d1452 100644
--- a/src/chainable.h
+++ b/src/chainable.h
@@ -13,6 +13,7 @@ struct Chainable {
   virtual ~Chainable() { }
   virtual void forward() { }
   virtual void backward() { }
+  virtual void check() { }
 
   virtual void init_dependent() { }
   virtual void set_zero_adjoint() { }
diff --git a/src/mnist_benchmark.cu b/src/mnist_benchmark.cu
index 5f0bc705..f0d1e9f9 100644
--- a/src/mnist_benchmark.cu
+++ b/src/mnist_benchmark.cu
@@ -40,7 +40,6 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
 
     biases.emplace_back(
       g.param(shape={1, out}, init=normal()));
-
   }
 
   auto probs = named(
diff --git a/src/node_operators.h b/src/node_operators.h
index db2031e9..7d66b703 100644
--- a/src/node_operators.h
+++ b/src/node_operators.h
@@ -92,8 +92,7 @@ struct UnaryNodeOp : public Node {
   template <typename ...Args>
   UnaryNodeOp(ChainPtr a, Args ...args)
   : Node(keywords::shape=a->shape(), //@TODO: Check keywords?
-         args...),
-    a_(a) {}
+         args...), a_(a) {}
 };
 
 struct LogitNodeOp : public UnaryNodeOp {
@@ -111,6 +110,10 @@ struct LogitNodeOp : public UnaryNodeOp {
             a_->grad(), adj_, val_);
   }
 
+  void check() {
+
+  }
+
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
@@ -171,10 +174,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
     // Classification." ICML 2016.
     // http://jmlr.org/proceedings/papers/v48/martins16.pdf
 
-    Tensor result(adj_.shape());
-    thrust::copy(adj_.begin(), adj_.end(), result.begin());
-    SubtractMean(&result, val_);
-    Element(_1 += _2 * _3, a_->grad(), val_, result);
+    SoftmaxGrad(a_->grad(), adj_, val_);
   }
 
   virtual std::string graphviz() {
@@ -183,7 +183,6 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
     ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
-
 };
 
 struct ArgmaxNodeOp : public UnaryNodeOp {
diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu
index ad30d051..5f5b8cc8 100644
--- a/src/tensor_operators.cu
+++ b/src/tensor_operators.cu
@@ -12,20 +12,22 @@ static cublasHandle_t create_handle() {
 }
 cublasHandle_t cublasHandle = create_handle();
 
-__global__ void gSubtractMean(float* out, float* weights,
-                              size_t rows, size_t cols) {
+__global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
+                             const int rows, const int cols) {
   for(int bid = 0; bid < rows; bid += gridDim.x) {
     int j = bid + blockIdx.x;
     if(j < rows) {
       extern __shared__ float _share[];
       float* _sum = _share + blockDim.x;
-      float* sp = out + j * cols;
-      float* w = weights + j * cols;
+
+      float* gradRow = grad + j * cols;
+      const float* adjRow = adj + j * cols;
+      const float* valRow = val + j * cols;
       _sum[threadIdx.x] = 0.0;
       for(int tid = 0; tid < cols; tid += blockDim.x) {
         int id = tid + threadIdx.x;
         if(id < cols) {
-          _sum[threadIdx.x] += w[id] * sp[id];
+          _sum[threadIdx.x] += valRow[id] * adjRow[id];
         }
       }
       __syncthreads();
@@ -41,25 +43,25 @@ __global__ void gSubtractMean(float* out, float* weights,
       for(int tid = 0; tid < cols; tid += blockDim.x){
         int id = tid + threadIdx.x;
         if(id < cols)
-          sp[id] -= _sum[0];
+          gradRow[id] += valRow[id] * (adjRow[id] - _sum[0]);
       }
     }
   }
 }
 
-void SubtractMean(Tensor* Out, Tensor &Weights) {
-  // Out and Weights are both m-by-k matrices, passed as input.
-  // A weighted average of each row of Out (according to the weights
-  // specified in Weights) is computed and subtracted from Out.
-  // Out is both input and output.
-  size_t m = Out->shape()[0];
-  size_t k = Out->shape()[1];
+void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
+  // grad, adj and val are all m-by-k matrices, passed as input.
+  // For each row, a weighted average of adj (weighted by val) is computed
+  // and subtracted from adj; the result is multiplied element-wise by val
+  // and accumulated into grad, the backward step of softmax in autodiff.
+  int m = grad.shape()[0];
+  int k = grad.shape()[1];
 
-  int blocks = std::min(MAX_BLOCKS, (int) m);
-  int threads = std::min(MAX_THREADS, (int) k);
+  int blocks = std::min(MAX_BLOCKS, m);
+  int threads = std::min(MAX_THREADS, k);
   int shared = sizeof(float) * threads * 2;
-  gSubtractMean<<<blocks, threads, shared>>>(Out->data(), Weights.data(),
-                                             m, k);
+  gSoftmaxGrad<<<blocks, threads, shared>>>(grad.data(), adj.data(), val.data(),
+                                            m, k);
   cudaStreamSynchronize(0);
 }
 
@@ -158,8 +160,9 @@ void Softmax(Tensor* Out) {
   gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
   cudaStreamSynchronize(0);
 }
+
 ///////////////////////////////////////////////////////
-__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols) {
+__global__ void gArgmax(float *out, const float *data, size_t rows, size_t cols) {
   size_t row = blockIdx.x;
   size_t startInd = row * cols;
   float maxScore = -99999;
@@ -182,7 +185,7 @@ void Argmax(Tensor* Out, const Tensor* In) {
   int blocks = m; //std::min(MAX_BLOCKS, (int) m);
   int threads = k; //std::min(MAX_THREADS, (int) k);
   //int shared = sizeof(float) * threads * 2;
-  gArgMax<<<blocks, threads>>>(Out->data(), In->data(), m, k);
+  gArgmax<<<blocks, threads>>>(Out->data(), In->data(), m, k);
   cudaStreamSynchronize(0);
 }
 
diff --git a/src/tensor_operators.h b/src/tensor_operators.h
index 039e6f39..06118e8f 100644
--- a/src/tensor_operators.h
+++ b/src/tensor_operators.h
@@ -142,20 +142,11 @@ void Element(Functor functor,
   cudaStreamSynchronize(0);
 }
 
-__global__ void gSubtractMean(float* out, float* weights,
-                              size_t rows, size_t cols);
-
-void SubtractMean(Tensor* Out, Tensor &Weights);
-
-__global__ void gSubtractMax(float* out, size_t rows, size_t cols);
-
 void SubtractMax(Tensor* Out);
 
-__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
-
 void Softmax(Tensor* Out);
 
-__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols);
+void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
 
 void Argmax(Tensor* Out, const Tensor* In);
 
diff --git a/src/test.cu b/src/test.cu
index 0f6a6334..5aa7de8f 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -17,33 +17,33 @@ string output(const std::vector<float> &vec)
   return strm.str();
 }
 
-void testArgMax()
-{
-  using namespace std;
-  using namespace marian;
-
-  std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
-  cerr << "hVec =" << output(hVec) << endl;
-
-  thrust::device_vector<float> dVec(8);
-  thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
-  float *data = thrust::raw_pointer_cast(dVec.data());
-
-  thrust::device_vector<float> dLabel(4);
-  float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
-
-  gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
-
-  std::vector<float> hVec2(8);
-  thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
-  cerr << "hVec2=" << output(hVec2) << endl;
-
-  std::vector<float> hLabel(4);
-  thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
-  cerr << "hLabel=" << output(hLabel) << endl;
-
-  exit(0);
-}
+//void testArgMax()
+//{
+//  using namespace std;
+//  using namespace marian;
+//
+//  std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
+//  cerr << "hVec =" << output(hVec) << endl;
+//
+//  thrust::device_vector<float> dVec(8);
+//  thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
+//  float *data = thrust::raw_pointer_cast(dVec.data());
+//
+//  thrust::device_vector<float> dLabel(4);
+//  float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
+//
+//  gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
+//
+//  std::vector<float> hVec2(8);
+//  thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
+//  cerr << "hVec2=" << output(hVec2) << endl;
+//
+//  std::vector<float> hLabel(4);
+//  thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
+//  cerr << "hLabel=" << output(hLabel) << endl;
+//
+//  exit(0);
+//}
 
 ///////////////////////////////////////////////////////
 int main(int argc, char** argv) {
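
Note: below is a minimal CPU reference sketch of what the fused gSoftmaxGrad kernel computes, which can be used to sanity-check the new single-kernel backward step against the old SubtractMean/Element pair. It is not part of the patch; the function name softmaxGradReference and the plain std::vector interface are illustrative assumptions. It assumes val holds the softmax outputs and adj the incoming adjoints, as in SoftmaxNodeOp::backward().

#include <cstddef>
#include <vector>

// Row-wise softmax backward: grad += val * (adj - sum(val * adj)),
// matching what gSoftmaxGrad accumulates via its shared-memory reduction.
void softmaxGradReference(std::vector<float>& grad,
                          const std::vector<float>& adj,
                          const std::vector<float>& val,
                          std::size_t rows, std::size_t cols) {
  for(std::size_t j = 0; j < rows; ++j) {
    const std::size_t offset = j * cols;
    // Weighted average of the adjoint row, with the softmax outputs as weights.
    float sum = 0.0f;
    for(std::size_t i = 0; i < cols; ++i)
      sum += val[offset + i] * adj[offset + i];
    // Accumulate into the existing gradient, as the kernel does per element.
    for(std::size_t i = 0; i < cols; ++i)
      grad[offset + i] += val[offset + i] * (adj[offset + i] - sum);
  }
}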