mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
add ArgmaxNodeOp
This commit is contained in:
parent
c824de0e82
commit
948789787d
@ -178,6 +178,35 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
|
||||
};
|
||||
|
||||
struct ArgmaxNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
ArgmaxNodeOp(ChainPtr a, Args ...args)
|
||||
: UnaryNodeOp(a, keywords::shape=newShape(a), args...) { }
|
||||
|
||||
void forward() {
|
||||
// B = softmax(A).
|
||||
Argmax(&val_, &a_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
}
|
||||
|
||||
Shape newShape(ChainPtr a) {
|
||||
Shape shape = a->shape();
|
||||
shape[1] = 1;
|
||||
return shape;
|
||||
}
|
||||
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
struct LogNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
LogNodeOp(Args ...args)
|
||||
|
Loading…
Reference in New Issue
Block a user