diff --git a/src/graph_operators.h b/src/graph_operators.h index 5a12f807..eb30ff29 100644 --- a/src/graph_operators.h +++ b/src/graph_operators.h @@ -101,6 +101,39 @@ struct TanhNodeOp : public UnaryNodeOp { } }; +struct ArgmaxOp : public UnaryNodeOp { + template + ArgmaxOp(ChainPtr a, Args ...args) + : UnaryNodeOp(a, keywords::shape=newShape(a, -1), args...), + axis_(-1) { } + + Shape newShape(ChainPtr a, int axis) { + Shape shape1 = a->shape(); + UTIL_THROW_IF2(shape1.size() > 2, + "Tensors with more than 2 dimensions not supported yet"); + if(axis == 0) { + shape1[0] = 1; + } + else if(axis == 1) { + shape1[1] = 1; + } + else { + shape1 = {1, 1}; + } + return shape1; + } + + void forward() { + //val_ = Argmax(a_->val(), axis_); + } + + void backward() {} + + private: + int axis_; +}; + + struct SoftmaxNodeOp : public UnaryNodeOp { template SoftmaxNodeOp(ChainPtr a, Args ...args) diff --git a/src/test.cu b/src/test.cu index a71939b4..0e9f9752 100644 --- a/src/test.cu +++ b/src/test.cu @@ -21,10 +21,14 @@ int main(int argc, char** argv) { Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0"); Expr b = param(shape={1, LABEL_SIZE}, name="b0"); - auto scores = dot(x, w) + b; - auto lr = softmax_fast(scores, axis=1, name="pred"); - auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost"); - cerr << "lr=" << lr.Debug() << endl; + auto z = dot(x, w) + b; + auto pred = softmax(z); + //auto decision = argmax(pred, axis=1); + + auto cost = -mean(sum(y * log(pred), axis=1), + axis=0); + + cerr << "pred=" << pred.Debug() << endl; #if 0 int numofdata; @@ -49,27 +53,27 @@ int main(int argc, char** argv) { x = tx; y = ty; - graph.forward(500); + cost.forward(500); std::cerr << "Result: "; - for (auto val : scores.val().shape()) { + for (auto val : pred.val().shape()) { std::cerr << val << " "; } std::cerr << std::endl; std::cerr << "Result: "; - for (auto val : lr.val().shape()) { + for (auto val : pred.val().shape()) { std::cerr << val << " "; } std::cerr << std::endl; - lr.val().Print(); + pred.val().Print(); std::cerr << "Log-likelihood: "; - for (auto val : graph.val().shape()) { + for (auto val : cost.val().shape()) { std::cerr << val << " "; } std::cerr << std::endl; - graph.val().Print(); + cost.val().Print(); - graph.backward(); + cost.backward(); //std::cerr << graph["pred"].val()[0] << std::endl;