mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
add argmax()
This commit is contained in:
parent
948789787d
commit
40ac84ba73
@ -33,6 +33,10 @@ Expr softmax_fast(Expr a) {
|
||||
return Expr(a.graph(), new SoftmaxNodeOp(a));
|
||||
}
|
||||
|
||||
Expr argmax(Expr a) {
|
||||
return Expr(a.graph(), new ArgmaxNodeOp(a));
|
||||
}
|
||||
|
||||
/*********************************************************/
|
||||
|
||||
static Shape newShape(ChainPtr a, ChainPtr b) {
|
||||
|
@ -79,6 +79,8 @@ Expr softmax(Expr a, Args ...args) {
|
||||
|
||||
Expr softmax_fast(Expr a);
|
||||
|
||||
Expr argmax(Expr a);
|
||||
|
||||
// inefficient
|
||||
template <typename ...Args>
|
||||
inline Expr mean(Expr a, Args ...args) {
|
||||
|
Loading…
Reference in New Issue
Block a user