From 40ac84ba73f20997b962220563ab53793268a0cd Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Fri, 16 Sep 2016 19:56:01 +0200 Subject: [PATCH] add argmax() --- src/expression_operators.cu | 4 ++++ src/expression_operators.h | 2 ++ 2 files changed, 6 insertions(+) diff --git a/src/expression_operators.cu b/src/expression_operators.cu index 73c23f1d..9f648768 100644 --- a/src/expression_operators.cu +++ b/src/expression_operators.cu @@ -33,6 +33,10 @@ Expr softmax_fast(Expr a) { return Expr(a.graph(), new SoftmaxNodeOp(a)); } +Expr argmax(Expr a) { + return Expr(a.graph(), new ArgmaxNodeOp(a)); +} + /*********************************************************/ static Shape newShape(ChainPtr a, ChainPtr b) { diff --git a/src/expression_operators.h b/src/expression_operators.h index 4cb69dbb..2072ddba 100644 --- a/src/expression_operators.h +++ b/src/expression_operators.h @@ -79,6 +79,8 @@ Expr softmax(Expr a, Args ...args) { Expr softmax_fast(Expr a); +Expr argmax(Expr a); + // inefficient template inline Expr mean(Expr a, Args ...args) {