add argmax()

This commit is contained in:
Hieu Hoang 2016-09-16 19:56:01 +02:00
parent 948789787d
commit 40ac84ba73
2 changed files with 6 additions and 0 deletions

View File

@ -33,6 +33,10 @@ Expr softmax_fast(Expr a) {
return Expr(a.graph(), new SoftmaxNodeOp(a));
}
Expr argmax(Expr a) {
return Expr(a.graph(), new ArgmaxNodeOp(a));
}
/*********************************************************/
static Shape newShape(ChainPtr a, ChainPtr b) {

View File

@ -79,6 +79,8 @@ Expr softmax(Expr a, Args ...args) {
Expr softmax_fast(Expr a);
Expr argmax(Expr a);
// inefficient
template <typename ...Args>
inline Expr mean(Expr a, Args ...args) {