diff --git a/src/expression_operators.cu b/src/expression_operators.cu index c4b5109d..59c1c52d 100644 --- a/src/expression_operators.cu +++ b/src/expression_operators.cu @@ -33,6 +33,10 @@ Expr softmax(Expr a) { return Expr(a.graph(), new SoftmaxNodeOp(a)); } +Expr argmax(Expr a) { + return Expr(a.graph(), new ArgmaxNodeOp(a)); +} + /*********************************************************/ static Shape newShape(ChainPtr a, ChainPtr b) { diff --git a/src/expression_operators.h b/src/expression_operators.h index 3207c5ed..6a9b4e53 100644 --- a/src/expression_operators.h +++ b/src/expression_operators.h @@ -79,6 +79,8 @@ Expr softmax_slow(Expr a, Args ...args) { Expr softmax(Expr a); +Expr argmax(Expr a); + // inefficient template inline Expr mean(Expr a, Args ...args) { diff --git a/src/node_operators.h b/src/node_operators.h index c63c9333..e7994c0a 100644 --- a/src/node_operators.h +++ b/src/node_operators.h @@ -185,6 +185,35 @@ struct SoftmaxNodeOp : public UnaryNodeOp { }; +struct ArgmaxNodeOp : public UnaryNodeOp { + template + ArgmaxNodeOp(ChainPtr a, Args ...args) + : UnaryNodeOp(a, keywords::shape=newShape(a), args...) { } + + void forward() { + // B = softmax(A). + Argmax(&val_, &a_->val()); + } + + void backward() { + } + + Shape newShape(ChainPtr a) { + Shape shape = a->shape(); + shape[1] = 1; + return shape; + } + + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + +}; + struct LogNodeOp : public UnaryNodeOp { template LogNodeOp(Args ...args)