diff --git a/src/expression_operators.h b/src/expression_operators.h
index 3c3dc031..8eabbd04 100644
--- a/src/expression_operators.h
+++ b/src/expression_operators.h
@@ -134,6 +134,7 @@ inline Expr sum(Expr a, Args ...args) {
   else if(ax == 1) {
     auto lshape = [n]() -> Shape {
       int cols = n->val().shape()[1];
+      //std::cerr << "Shape will be " << cols << " by 1." << std::endl;
       return {cols, 1};
     };
     Expr one = ones(shape={n->shape()[1], 1},
@@ -153,6 +154,20 @@ inline Expr sum(Expr a, Args ...args) {
 
 template <typename ...Args>
 inline Expr softmax(Expr a, Args ...args) {
   Expr e = exp(a);
+#if 0
+  ChainPtr n = a.node();
+  auto print_shape = [n]() -> Shape {
+    std::cerr << "Shape: ";
+    for (auto val : n->val().shape()) {
+      std::cerr << val << " ";
+    }
+    std::cerr << std::endl;
+    return {1,1};
+  };
+  using namespace keywords;
+  Expr one = ones(shape={1, 1}, lazy_shape=print_shape);
+#endif
+
   return e / sum(e, args...);
 }
@@ -187,4 +202,4 @@ inline Expr mean(Expr a, Args ...args) {
   }
 }
 
-}
\ No newline at end of file
+}
diff --git a/src/graph_operators.h b/src/graph_operators.h
index d07c4b38..30456153 100644
--- a/src/graph_operators.h
+++ b/src/graph_operators.h
@@ -118,9 +118,15 @@ struct LogNodeOp : public UnaryNodeOp {
 
 struct ExpNodeOp : public UnaryNodeOp {
   template <typename ...Args>
-  ExpNodeOp(Args ...args)
-  : UnaryNodeOp(args...) { }
+  ExpNodeOp(ChainPtr a, Args ...args)
+  : UnaryNodeOp(a, keywords::shape=newShape(a),
+                args...) { }
+
+  Shape newShape(ChainPtr a) {
+    Shape shape = a->shape();
+    return shape;
+  }
 
   void forward() {
     Element(_1 = Exp(_2), val_, a_->val());
   }
@@ -289,4 +295,4 @@ struct DivNodeOp : public BroadcastingNodeOp {
   }
 };
 
-}
\ No newline at end of file
+}
diff --git a/src/test.cu b/src/test.cu
index 56156fee..48b996f8 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -12,7 +12,8 @@ int main(int argc, char** argv) {
   auto w = param(shape={784, 10}, name="W0");
   auto b = param(shape={1, 10}, name="b0");
 
-  auto lr = softmax(dot(x, w) + b, axis=1, name="pred");
+  auto scores = dot(x, w) + b;
+  auto lr = softmax(scores, axis=1, name="pred");
   auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
 
   Tensor tx({500, 784}, 1);
@@ -22,28 +23,44 @@ int main(int argc, char** argv) {
   y = ty;
 
   graph.forward(500);
 
+  std::cerr << "Result: ";
+  for (auto val : scores.val().shape()) {
+    std::cerr << val << " ";
+  }
+  std::cerr << std::endl;
+  std::cerr << "Result: ";
+  for (auto val : lr.val().shape()) {
+    std::cerr << val << " ";
+  }
+  std::cerr << std::endl;
+  std::cerr << "Log-likelihood: ";
+  for (auto val : graph.val().shape()) {
+    std::cerr << val << " ";
+  }
+  std::cerr << std::endl;
+
   //std::cerr << graph["pred"].val()[0] << std::endl;
-
-  //hook0(graph);
-  //graph.autodiff();
-  //std::cerr << graph["cost"].val()[0] << std::endl;
+#if 0
+  hook0(graph);
+  graph.autodiff();
+  std::cerr << graph["cost"].val()[0] << std::endl;
   //hook1(graph);
-  //for(auto p : graph.params()) {
-  //  auto update = _1 = _1 - alpha * _2;
-  //  Element(update, p.val(), p.grad());
-  //}
-  //hook2(graph);
-  //
-  //auto opt = adadelta(cost_function=cost,
-  //                    eta=0.9, gamma=0.1,
-  //                    set_batch=set,
-  //                    before_update=before,
-  //                    after_update=after,
-  //                    set_valid=valid,
-  //                    validation_freq=100,
-  //                    verbose=1, epochs=3, early_stopping=10);
-  //opt.run();
+  for(auto p : graph.params()) {
+    auto update = _1 = _1 - alpha * _2;
+    Element(update, p.val(), p.grad());
+  }
+  hook2(graph);
+  auto opt = adadelta(cost_function=cost,
+                      eta=0.9, gamma=0.1,
+                      set_batch=set,
+                      before_update=before,
+                      after_update=after,
+                      set_valid=valid,
+                      validation_freq=100,
+                      verbose=1, epochs=3, early_stopping=10);
+  opt.run();
+#endif
 
   return 0;
 }
\ No newline at end of file