From 6d3f67e9555d76c7926728b8aae07e535faeb028 Mon Sep 17 00:00:00 2001 From: Andre Martins Date: Thu, 15 Sep 2016 13:54:46 +0100 Subject: [PATCH] Fixed backward for fast softmax. --- src/expression_operators.h | 14 -------------- src/graph_operators.h | 7 +++---- src/test.cu | 2 +- 3 files changed, 4 insertions(+), 19 deletions(-) diff --git a/src/expression_operators.h b/src/expression_operators.h index 253047d3..957ceed1 100644 --- a/src/expression_operators.h +++ b/src/expression_operators.h @@ -153,20 +153,6 @@ inline Expr sum(Expr a, Args ...args) { template inline Expr softmax(Expr a, Args ...args) { Expr e = exp(a); -#if 0 - ChainPtr n = a.node(); - auto print_shape = [n]() -> Shape { - std::cerr << "Shape: "; - for (auto val : n->val().shape()) { - std::cerr << val << " "; - } - std::cerr << std::endl; - return {1,1}; - }; - using namespace keywords; - Expr one = ones(shape={1, 1}, lazy_shape=print_shape); -#endif - return e / sum(e, args...); } diff --git a/src/graph_operators.h b/src/graph_operators.h index c7c0a057..a6320201 100644 --- a/src/graph_operators.h +++ b/src/graph_operators.h @@ -162,11 +162,10 @@ struct SoftmaxNodeOp : public UnaryNodeOp { // For each row, the Jacobian times vector is given by: // J * dy = p .* (dy - avg*1) // where avg = p'*dy and p is the softmax output (probabilities). - Tensor result = adj_; + Tensor result(adj_.shape()); + thrust::copy(adj_.begin(), adj_.end(), result.begin()); SubtractMean(&result, val_); - // beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C - // to sum gradients from different graph parts. - Prod(a_->grad(), adj_, result, false, false, 1.0); + Element(_1 += _2 * _3, a_->grad(), val_, result); } }; diff --git a/src/test.cu b/src/test.cu index 629c1bc2..85636dd3 100644 --- a/src/test.cu +++ b/src/test.cu @@ -19,7 +19,7 @@ int main(int argc, char** argv) { //Expr b = param(shape={1, 2}, name="b0"); std::cerr << "Building model..."; - auto predict = softmax(dot(x, w), + auto predict = softmax_fast(dot(x, w), axis=1, name="pred"); auto graph = -mean(sum(y * log(predict), axis=1), axis=0, name="cost");