From 948789787d80282bd0e4579e3a09726b23fd134b Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Fri, 16 Sep 2016 19:49:59 +0200 Subject: [PATCH 1/2] add ArgmaxNodeOp --- src/node_operators.h | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/node_operators.h b/src/node_operators.h index 8620a645..1de5e37d 100644 --- a/src/node_operators.h +++ b/src/node_operators.h @@ -178,6 +178,35 @@ struct SoftmaxNodeOp : public UnaryNodeOp { }; +struct ArgmaxNodeOp : public UnaryNodeOp { + template + ArgmaxNodeOp(ChainPtr a, Args ...args) + : UnaryNodeOp(a, keywords::shape=newShape(a), args...) { } + + void forward() { + // B = softmax(A). + Argmax(&val_, &a_->val()); + } + + void backward() { + } + + Shape newShape(ChainPtr a) { + Shape shape = a->shape(); + shape[1] = 1; + return shape; + } + + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + +}; + struct LogNodeOp : public UnaryNodeOp { template LogNodeOp(Args ...args) From 40ac84ba73f20997b962220563ab53793268a0cd Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Fri, 16 Sep 2016 19:56:01 +0200 Subject: [PATCH 2/2] add argmax() --- src/expression_operators.cu | 4 ++++ src/expression_operators.h | 2 ++ 2 files changed, 6 insertions(+) diff --git a/src/expression_operators.cu b/src/expression_operators.cu index 73c23f1d..9f648768 100644 --- a/src/expression_operators.cu +++ b/src/expression_operators.cu @@ -33,6 +33,10 @@ Expr softmax_fast(Expr a) { return Expr(a.graph(), new SoftmaxNodeOp(a)); } +Expr argmax(Expr a) { + return Expr(a.graph(), new ArgmaxNodeOp(a)); +} + /*********************************************************/ static Shape newShape(ChainPtr a, ChainPtr b) { diff --git a/src/expression_operators.h b/src/expression_operators.h index 4cb69dbb..2072ddba 100644 --- a/src/expression_operators.h +++ b/src/expression_operators.h @@ -79,6 +79,8 @@ Expr softmax(Expr a, Args ...args) { Expr softmax_fast(Expr a); +Expr argmax(Expr a); + // inefficient template inline Expr mean(Expr a, Args ...args) {