add argmax()

This commit is contained in:
Hieu Hoang 2016-09-16 19:56:01 +02:00
parent 948789787d
commit 40ac84ba73
2 changed files with 6 additions and 0 deletions

View File

@ -33,6 +33,10 @@ Expr softmax_fast(Expr a) {
return Expr(a.graph(), new SoftmaxNodeOp(a)); return Expr(a.graph(), new SoftmaxNodeOp(a));
} }
Expr argmax(Expr a) {
return Expr(a.graph(), new ArgmaxNodeOp(a));
}
/*********************************************************/ /*********************************************************/
static Shape newShape(ChainPtr a, ChainPtr b) { static Shape newShape(ChainPtr a, ChainPtr b) {

View File

@ -79,6 +79,8 @@ Expr softmax(Expr a, Args ...args) {
Expr softmax_fast(Expr a); Expr softmax_fast(Expr a);
Expr argmax(Expr a);
// inefficient // inefficient
template <typename ...Args> template <typename ...Args>
inline Expr mean(Expr a, Args ...args) { inline Expr mean(Expr a, Args ...args) {