Merge branch 'master' of github.com:emjotde/marian

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-16 23:27:05 +02:00
commit 70f8277eb1
3 changed files with 35 additions and 0 deletions

View File

@ -33,6 +33,10 @@ Expr softmax(Expr a) {
return Expr(a.graph(), new SoftmaxNodeOp(a));
}
Expr argmax(Expr a) {
return Expr(a.graph(), new ArgmaxNodeOp(a));
}
/*********************************************************/
static Shape newShape(ChainPtr a, ChainPtr b) {

View File

@ -79,6 +79,8 @@ Expr softmax_slow(Expr a, Args ...args) {
Expr softmax(Expr a);
Expr argmax(Expr a);
// inefficient
template <typename ...Args>
inline Expr mean(Expr a, Args ...args) {

View File

@ -185,6 +185,35 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
};
struct ArgmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
ArgmaxNodeOp(ChainPtr a, Args ...args)
: UnaryNodeOp(a, keywords::shape=newShape(a), args...) { }
void forward() {
// B = softmax(A).
Argmax(&val_, &a_->val());
}
void backward() {
}
Shape newShape(ChainPtr a) {
Shape shape = a->shape();
shape[1] = 1;
return shape;
}
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
};
struct LogNodeOp : public UnaryNodeOp {
template <typename ...Args>
LogNodeOp(Args ...args)