mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merge branch 'master' of github.com:emjotde/marian
This commit is contained in:
commit
70f8277eb1
@ -33,6 +33,10 @@ Expr softmax(Expr a) {
|
|||||||
return Expr(a.graph(), new SoftmaxNodeOp(a));
|
return Expr(a.graph(), new SoftmaxNodeOp(a));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Expr argmax(Expr a) {
|
||||||
|
return Expr(a.graph(), new ArgmaxNodeOp(a));
|
||||||
|
}
|
||||||
|
|
||||||
/*********************************************************/
|
/*********************************************************/
|
||||||
|
|
||||||
static Shape newShape(ChainPtr a, ChainPtr b) {
|
static Shape newShape(ChainPtr a, ChainPtr b) {
|
||||||
|
@ -79,6 +79,8 @@ Expr softmax_slow(Expr a, Args ...args) {
|
|||||||
|
|
||||||
Expr softmax(Expr a);
|
Expr softmax(Expr a);
|
||||||
|
|
||||||
|
Expr argmax(Expr a);
|
||||||
|
|
||||||
// inefficient
|
// inefficient
|
||||||
template <typename ...Args>
|
template <typename ...Args>
|
||||||
inline Expr mean(Expr a, Args ...args) {
|
inline Expr mean(Expr a, Args ...args) {
|
||||||
|
@ -185,6 +185,35 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
|
|||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ArgmaxNodeOp : public UnaryNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
ArgmaxNodeOp(ChainPtr a, Args ...args)
|
||||||
|
: UnaryNodeOp(a, keywords::shape=newShape(a), args...) { }
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
// B = softmax(A).
|
||||||
|
Argmax(&val_, &a_->val());
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
}
|
||||||
|
|
||||||
|
Shape newShape(ChainPtr a) {
|
||||||
|
Shape shape = a->shape();
|
||||||
|
shape[1] = 1;
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||||
|
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
struct LogNodeOp : public UnaryNodeOp {
|
struct LogNodeOp : public UnaryNodeOp {
|
||||||
template <typename ...Args>
|
template <typename ...Args>
|
||||||
LogNodeOp(Args ...args)
|
LogNodeOp(Args ...args)
|
||||||
|
Loading…
Reference in New Issue
Block a user