diff --git a/src/expression_operators.cu b/src/expression_operators.cu
index 73c23f1d..c4b5109d 100644
--- a/src/expression_operators.cu
+++ b/src/expression_operators.cu
@@ -29,7 +29,7 @@ Expr operator-(Expr a) {
   return Expr(a.graph(), new NegNodeOp(a));
 };
 
-Expr softmax_fast(Expr a) {
+Expr softmax(Expr a) {
   return Expr(a.graph(), new SoftmaxNodeOp(a));
 }
 
diff --git a/src/expression_operators.h b/src/expression_operators.h
index 4cb69dbb..3207c5ed 100644
--- a/src/expression_operators.h
+++ b/src/expression_operators.h
@@ -72,12 +72,12 @@ inline Expr sum(Expr a, Args ...args) {
 
 // inefficient
 template <typename ...Args>
-Expr softmax(Expr a, Args ...args) {
+Expr softmax_slow(Expr a, Args ...args) {
   Expr e = exp(a);
   return e / sum(e, args...);
 }
 
-Expr softmax_fast(Expr a);
+Expr softmax(Expr a);
 
 // inefficient
 template <typename ...Args>
diff --git a/src/sgd.h b/src/sgd.h
index 85a4e4af..fe0470b1 100644
--- a/src/sgd.h
+++ b/src/sgd.h
@@ -14,7 +14,8 @@ class Sgd {
     graph.backprop(batchSize);
 
     for(auto& param : graph.params())
-      Element(_1 -= eta_ * _2, param.val(), param.grad());
+      Element(_1 -= eta_ * _2,
+              param.val(), param.grad());
   }
 
   private:
@@ -36,7 +37,8 @@ class Adagrad {
     auto it = history_.begin();
     for(auto& param : graph.params()) {
       Element(_1 += _2 * _2, *it, param.grad());
-      Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3, param.val(), *it, param.grad());
+      Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3,
+              param.val(), *it, param.grad());
       it++;
     }
   }
diff --git a/src/test.cu b/src/test.cu
index 1201ad99..8c7dfc54 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -100,10 +100,10 @@ int main(int argc, char** argv) {
   std::cerr << "Building output layer..." << std::endl;
 
   std::vector<Expr> Yp;
-  Yp.emplace_back(softmax_fast(dot(H[0], Why) + by));
+  Yp.emplace_back(softmax(dot(H[0], Why) + by));
   Expr cross_entropy = sum(Y[0] * log(Yp[0]), axis=1);
   for (int t = 1; t < num_inputs; ++t) {
-    Yp.emplace_back(softmax_fast(dot(H[t], Why) + by));
+    Yp.emplace_back(softmax(dot(H[t], Why) + by));
     cross_entropy = cross_entropy + sum(Y[t] * log(Yp[t]), axis=1);
   }
   auto graph = -mean(cross_entropy, axis=0, name="cost");
diff --git a/src/train_mnist.cu b/src/train_mnist.cu
index 2dda8fde..e726ee83 100644
--- a/src/train_mnist.cu
+++ b/src/train_mnist.cu
@@ -25,7 +25,7 @@ int main(int argc, char** argv) {
   Expr b = named(g.param(shape={1, LABEL_SIZE}), "b");
 
   auto scores = dot(x, w) + b;
-  auto lr = softmax_fast(scores);
+  auto lr = softmax(scores);
   auto cost = named(-mean(sum(y * log(lr), axis=1), axis=0), "cost");
 
   cerr << "lr=" << lr.Debug() << endl;
diff --git a/src/validate_encoder_decoder.cu b/src/validate_encoder_decoder.cu
index d1f54bde..c98e6245 100644
--- a/src/validate_encoder_decoder.cu
+++ b/src/validate_encoder_decoder.cu
@@ -80,10 +80,10 @@ ExpressionGraph build_graph() {
   // Softmax layer and cost function.
   std::vector<Expr> Yp;
-  Yp.emplace_back(named(softmax_fast(dot(h0_d, Why) + by), "pred"));
+  Yp.emplace_back(named(softmax(dot(h0_d, Why) + by), "pred"));
   Expr cross_entropy = sum(Y[0] * log(Yp[0]), axis=1);
   for (int t = 1; t <= num_outputs; ++t) {
-    Yp.emplace_back(named(softmax_fast(dot(S[t-1], Why) + by), "pred"));
+    Yp.emplace_back(named(softmax(dot(S[t-1], Why) + by), "pred"));
     cross_entropy = cross_entropy + sum(Y[t] * log(Yp[t]), axis=1);
   }
   auto cost = named(-mean(cross_entropy, axis=0), "cost");
diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu
index f9bc0dcf..690d6f40 100644
--- a/src/validate_mnist.cu
+++ b/src/validate_mnist.cu
@@ -32,7 +32,7 @@ ExpressionGraph build_graph() {
     init=from_vector(bData)), "b");
 
   auto probs = named(
-    softmax_fast(dot(x, w) + b), //, axis=1),
+    softmax(dot(x, w) + b), //, axis=1),
     "probs"
   );
 
diff --git a/src/validate_mnist_batch.cu b/src/validate_mnist_batch.cu
index d37e9ca3..76379eda 100644
--- a/src/validate_mnist_batch.cu
+++ b/src/validate_mnist_batch.cu
@@ -68,7 +68,7 @@ int main(int argc, char** argv) {
   std::cerr << "Building model...";
 
   auto layer1 = tanh(dot(x, w1) + b1);
-  auto layer2 = softmax(dot(layer1, w2) + b2, axis=1, name="layer2");
+  auto layer2 = softmax(dot(layer1, w2) + b2);
   auto predict = layer2;
 
   std::cerr << "Done." << std::endl;