From 976c8039db6547d42061076b7315eefe1a05ab79 Mon Sep 17 00:00:00 2001
From: Marcin Junczys-Dowmunt
Date: Thu, 15 Sep 2016 18:42:38 +0200
Subject: [PATCH] separated computation graphs

---
 src/CMakeLists.txt                          |   5 +-
 src/chainable.h                             |  34 +++++
 src/expression_graph.cu                     |  41 ++++++
 src/expression_graph.h                      |  96 ++++++++++++++
 src/expression_operators.cu                 | 121 +++++++++++++++++
 src/expression_operators.h                  | 128 +++--------------
 src/expressions.cu                          |  59 ---------
 src/expressions.h                           |  33 -----
 src/marian.h                                |   6 +-
 src/{graph.h => node.h}                     |  32 +----
 src/{graph_operators.h => node_operators.h} | 139 ++++++++------------
 src/param_initializers.h                    |   2 +-
 src/validate_mnist.cu                       |  22 ++--
 13 files changed, 385 insertions(+), 333 deletions(-)
 create mode 100644 src/chainable.h
 create mode 100644 src/expression_graph.cu
 create mode 100644 src/expression_graph.h
 create mode 100644 src/expression_operators.cu
 delete mode 100644 src/expressions.cu
 delete mode 100644 src/expressions.h
 rename src/{graph.h => node.h} (68%)
 rename src/{graph_operators.h => node_operators.h} (70%)

diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index cb121111..1ad22e9d 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -4,9 +4,10 @@ include_directories(.)
 cuda_add_library(marian_lib
   cnpy/cnpy.cpp
   exception.cpp
-  expressions.cu
-  tensor.cu
+  expression_graph.cu
+  tensor.cu
   tensor_operators.cu
+  expression_operators.cu
 )
 
 target_link_libraries(marian_lib)
diff --git a/src/chainable.h b/src/chainable.h
new file mode 100644
index 00000000..9fe6d208
--- /dev/null
+++ b/src/chainable.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <vector>
+#include <memory>
+
+#include "exception.h"
+
+namespace marian {
+
+template <class DataType>
+struct Chainable {
+  Chainable() { }
+  virtual ~Chainable() { }
+  virtual void forward() { }
+  virtual void backward() { }
+  virtual void init_dependent() { }
+  virtual void set_zero_adjoint() { }
+
+  virtual void allocate(size_t) = 0;
+
+  virtual const Shape& shape() = 0;
+  virtual DataType &val() = 0;
+  virtual DataType grad() = 0;
+  virtual void setVal(DataType t) {
+    UTIL_THROW2("Tensors can only be assigned to input nodes");
+  };
+};
+
+typedef std::vector<Chainable<Tensor>*> ChainableStack;
+typedef std::shared_ptr<ChainableStack> ChainableStackPtr;
+typedef std::shared_ptr<Chainable<Tensor>> ChainPtr;
+
+
+}
\ No newline at end of file
diff --git a/src/expression_graph.cu b/src/expression_graph.cu
new file mode 100644
index 00000000..61f8d2b5
--- /dev/null
+++ b/src/expression_graph.cu
@@ -0,0 +1,41 @@
+#include <sstream>
+#include "expression_graph.h"
+
+using namespace std;
+
+namespace marian {
+
+Expr::Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable)
+  : graph_(g), pimpl_(chainable) {
+  graph_->stack()->push_back(chainable);
+}
+
+Tensor Expr::val() {
+  return pimpl_->val();
+}
+
+Tensor Expr::grad() {
+  return pimpl_->grad();
+}
+
+ChainPtr Expr::node() {
+  return pimpl_;
+}
+
+ExpressionGraphPtr Expr::graph() {
+  return graph_;
+}
+
+Expr::operator ChainPtr() {
+  return pimpl_;
+}
+
+std::string Expr::Debug() const
+{
+  stringstream strm;
+  const Shape &shape = pimpl_->shape();
+  strm << marian::Debug(shape);
+  return strm.str();
+}
+
+}
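The tape discipline behind ChainableStack is: forward() replays the nodes in
construction order, backward() seeds the root via init_dependent() and sweeps
the tape in reverse. For intuition, a minimal self-contained sketch of that
discipline, with plain double standing in for Tensor (toy types, not part of
this patch):

    #include <iostream>
    #include <memory>
    #include <vector>

    struct ToyChainable {                       // cut-down Chainable<double>
      double val = 0, adj = 0;                  // adjoints start at zero
      virtual void forward() { }
      virtual void backward() { }
      virtual ~ToyChainable() { }
    };

    struct Square : ToyChainable {              // y = x * x
      ToyChainable* x;
      explicit Square(ToyChainable* a) : x(a) { }
      void forward() override  { val = x->val * x->val; }
      void backward() override { x->adj += 2 * x->val * adj; }  // chain rule
    };

    int main() {
      std::vector<std::unique_ptr<ToyChainable>> stack;   // the tape
      stack.emplace_back(new ToyChainable);               // input x
      stack.front()->val = 3;
      stack.emplace_back(new Square(stack.front().get()));
      for (auto& n : stack) n->forward();                 // tape order
      stack.back()->adj = 1;                              // init_dependent()
      for (auto it = stack.rbegin(); it != stack.rend(); ++it)
        (*it)->backward();                                // reverse sweep
      std::cout << stack.front()->adj << "\n";            // 6 = d(x^2)/dx at 3
    }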
diff --git a/src/expression_graph.h b/src/expression_graph.h
new file mode 100644
index 00000000..dd62b449
--- /dev/null
+++ b/src/expression_graph.h
@@ -0,0 +1,96 @@
+#pragma once
+
+#include "definitions.h"
+#include "chainable.h"
+#include "node_operators.h"
+#include "tensor.h"
+
+namespace marian {
+
+class ExpressionGraph;
+typedef ExpressionGraph* ExpressionGraphPtr;
+
+class Expr {
+  public:
+    Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable);
+
+    Expr operator=(Tensor t) {
+      pimpl_->setVal(t);
+      return *this;
+    }
+
+    Tensor val();
+    Tensor grad();
+
+    ExpressionGraphPtr graph();
+
+    ChainPtr node();
+    operator ChainPtr();
+
+    std::string Debug() const;
+
+  private:
+    ExpressionGraphPtr graph_;
+    ChainPtr pimpl_;
+};
+
+class ExpressionGraph {
+  public:
+    ExpressionGraph()
+    : stack_(new ChainableStack)
+    {}
+
+    void forward(size_t batchSize) {
+      for(auto&& v : *stack_) {
+        v->allocate(batchSize);
+      }
+      for(auto&& v : *stack_)
+        v->forward();
+    }
+
+    void backward() {
+      for(auto&& v : *stack_)
+        v->set_zero_adjoint();
+
+      typedef typename ChainableStack::reverse_iterator It;
+      stack_->back()->init_dependent();
+      for(It it = stack_->rbegin(); it != stack_->rend(); ++it)
+        (*it)->backward();
+    }
+
+    template <typename ...Args>
+    inline Expr input(Args ...args) {
+      return Expr(this, new InputNode(args...));
+    }
+
+    template <typename ...Args>
+    inline Expr param(Args ...args) {
+      return Expr(this, new ParamNode(args...));
+    }
+
+    template <typename ...Args>
+    inline Expr constant(Args ...args) {
+      return Expr(this, new ConstantNode(args...));
+    }
+
+    template <typename ...Args>
+    inline Expr ones(Args ...args) {
+      return Expr(this, new ConstantNode(keywords::value=1, args...));
+    }
+
+    template <typename ...Args>
+    inline Expr zeroes(Args ...args) {
+      return Expr(this, new ConstantNode(keywords::value=0, args...));
+    }
+
+    /*********************************************************/
+
+    ChainableStackPtr stack() {
+      return stack_;
+    }
+
+  private:
+    ChainableStackPtr stack_;
+};
+
+}
diff --git a/src/expression_operators.cu b/src/expression_operators.cu
new file mode 100644
index 00000000..4819efce
--- /dev/null
+++ b/src/expression_operators.cu
@@ -0,0 +1,121 @@
+
+#include "expression_operators.h"
+#include "node_operators.h"
+
+namespace marian {
+
+Expr logit(Expr a) {
+  return Expr(a.graph(), new LogitNodeOp(a));
+}
+
+Expr tanh(Expr a) {
+  return Expr(a.graph(), new TanhNodeOp(a));
+}
+
+Expr log(Expr a) {
+  return Expr(a.graph(), new LogNodeOp(a));
+};
+
+Expr exp(Expr a) {
+  return Expr(a.graph(), new ExpNodeOp(a));
+};
+
+Expr operator-(Expr a) {
+  return Expr(a.graph(), new NegNodeOp(a));
+};
+
+/*********************************************************/
+
+Expr broadcast(Shape bShape, Expr a) {
+  const Shape& aShape = a.node()->shape();
+  if(aShape == bShape) {
+    return a;
+  }
+  else {
+    size_t dimsA = aShape.size();
+    size_t dimsB = bShape.size();
+    UTIL_THROW_IF2(dimsA != dimsB,
+                   "Tensor and shape have different number of dimensions");
+    for(size_t i = 0; i < dimsA; ++i) {
+      int dimA = aShape[i];
+      int dimB = bShape[i];
+      bool broadcastable = (dimA == dimB || dimA == 1);
+      UTIL_THROW_IF2(!broadcastable,
+                     "Cannot broadcast tensor dimension "
+                     << dimA << " to " << dimB);
+      if(dimA == 1 && dimB != 1) {
+        if(i == 0) {
+          Expr one = a.graph()->ones(keywords::shape={bShape[0], 1});
+          a = dot(one, a);
+        }
+        else if(i == 1) {
+          Expr one = a.graph()->ones(keywords::shape={1, bShape[1]});
+          a = dot(a, one);
+        }
+        else {
+          UTIL_THROW2("Not implemented");
+        }
+      }
+    }
+    return a;
+  }
+}
+
+static Shape newShape(ChainPtr a, ChainPtr b) {
+  size_t dimsA = a->shape().size();
+  size_t dimsB = b->shape().size();
+  UTIL_THROW_IF2(dimsA != dimsB,
+                 "Tensors have different numbers of dimensions");
+  Shape shape(dimsA);
+  for(size_t i = 0; i < dimsA; ++i) {
+    int dimA = a->shape()[i];
+    int dimB = b->shape()[i];
+    bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1);
+    UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
+                   << "operation cannot be broadcasted: " << dimA << " != " << dimB);
+    shape[i] = std::max(dimA, dimB);
+    if(dimA == whatevs || dimB == whatevs)
+      shape[i] = whatevs;
+  }
+  return shape;
+}
+
+Expr operator+(Expr a, Expr b) {
+  Shape shape = newShape(a, b);
+  Expr cast_a = broadcast(shape, a);
+  Expr cast_b = broadcast(shape, b);
+  return Expr(a.graph(), new PlusNodeOp(cast_a, cast_b));
+}
+
+Expr operator-(Expr a, Expr b) {
+  Shape shape = newShape(a, b);
+  Expr cast_a = broadcast(shape, a);
+  Expr cast_b = broadcast(shape, b);
+  return Expr(a.graph(), new MinusNodeOp(cast_a, cast_b));
+}
+
+Expr operator*(Expr a, Expr b) {
+  Shape shape = newShape(a, b);
+  Expr cast_a = broadcast(shape, a);
+  Expr cast_b = broadcast(shape, b);
+  return Expr(a.graph(), new MultNodeOp(cast_a, cast_b));
+}
+
+Expr operator/(Expr a, Expr b) {
+  Shape shape = newShape(a, b);
+  Expr cast_a = broadcast(shape, a);
+  Expr cast_b = broadcast(shape, b);
+  return Expr(a.graph(), new DivNodeOp(cast_a, cast_b));
+}
+
+Expr dot(Expr a, Expr b) {
+  return Expr(a.graph(), new DotNodeOp(a, b));
+}
+
+/******************************************************/
+
+
+}
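broadcast() above realizes dimension-1 replication as a matrix product: a
{1,N} row is lifted to {M,N} by left-multiplying with an {M,1} column of
ones, so DotNodeOp's existing gradient machinery is reused instead of a
dedicated broadcasting kernel. A worked instance in plain C++ (stand-alone
sketch, no marian types, {rows, cols} shape convention as above):

    #include <cstdio>

    int main() {
      const int M = 3, N = 2;
      double row[N] = {10, 20};           // shape {1, N}, e.g. a bias row
      double out[M][N];                   // dot(ones{M,1}, row) -> {M, N}
      for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j)
          out[i][j] = 1.0 * row[j];       // ones column contributes factor 1
      for (int i = 0; i < M; ++i)
        std::printf("%g %g\n", out[i][0], out[i][1]);  // every row is: 10 20
    }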
diff --git a/src/expression_operators.h b/src/expression_operators.h
index 253047d3..6082d28c 100644
--- a/src/expression_operators.h
+++ b/src/expression_operators.h
@@ -1,115 +1,34 @@
 #pragma once
 
-#include "graph.h"
-#include "graph_operators.h"
-#include "expressions.h"
+#include "expression_graph.h"
 
 namespace marian {
 
-template <typename ...Args>
-inline Expr input(Args ...args) {
-  return Expr(new InputNode(args...));
-}
+Expr logit(Expr a);
 
-template <typename ...Args>
-inline Expr param(Args ...args) {
-  return Expr(new ParamNode(args...));
-}
-template <typename ...Args>
-inline Expr constant(Args ...args) {
-  return Expr(new ConstantNode(args...));
-}
+Expr tanh(Expr a);
 
-template <typename ...Args>
-inline Expr ones(Args ...args) {
-  return Expr(new ConstantNode(keywords::value=1, args...));
-}
+Expr log(Expr a);
 
-template <typename ...Args>
-inline Expr zeroes(Args ...args) {
-  return Expr(new ConstantNode(keywords::value=0, args...));
-}
+Expr exp(Expr a);
+
+Expr operator-(Expr a);
 
 /*********************************************************/
 
-inline Expr logit(Expr a) {
-  return Expr(new LogitNodeOp(a));
-}
+Expr operator+(Expr a, Expr b);
 
-inline Expr tanh(Expr a) {
-  return Expr(new TanhNodeOp(a));
-}
+Expr operator-(Expr a, Expr b);
 
-inline Expr log(Expr a) {
-  return Expr(new LogNodeOp(a));
-};
+Expr operator*(Expr a, Expr b);
 
-inline Expr exp(Expr a) {
-  return Expr(new ExpNodeOp(a));
-};
+Expr operator/(Expr a, Expr b);
 
-inline Expr operator-(Expr a) {
-  return Expr(new NegNodeOp(a));
-};
-
-/*********************************************************/
-
-inline Expr operator+(Expr a, Expr b) {
-  return Expr(new PlusNodeOp(a, b));
-}
-
-inline Expr operator-(Expr a, Expr b) {
-  return Expr(new MinusNodeOp(a, b));
-}
-
-inline Expr operator*(Expr a, Expr b) {
-  return Expr(new MultNodeOp(a, b));
-}
-
-inline Expr operator/(Expr a, Expr b) {
-  return Expr(new DivNodeOp(a, b));
-}
-
-inline Expr dot(Expr a, Expr b) {
-  return Expr(new DotNodeOp(a, b));
-}
+Expr dot(Expr a, Expr b);
 
 /******************************************************/
 
-Expr broadcast(Shape bShape, Expr a) {
-  const Shape& aShape = a.node()->shape();
-  if(aShape == bShape) {
-    return a;
-  }
-  else {
-    size_t dimsA = aShape.size();
-    size_t dimsB = bShape.size();
-    UTIL_THROW_IF2(dimsA != dimsB,
-                   "Tensor and shape have different number of dimensions");
-    for(size_t i = 0; i < dimsA; ++i) {
-      int dimA = aShape[i];
-      int dimB = bShape[i];
-      bool broadcastable = (dimA == dimB || dimA == 1);
-      UTIL_THROW_IF2(!broadcastable,
-                     "Cannot broadcast tensor dimension "
-                     << dimA << " to " << dimB);
-      if(dimA == 1 && dimB != 1) {
-        if(i == 0) {
-          Expr one = ones(keywords::shape={bShape[0], 1});
-          a = dot(one, a);
-        }
-        else if(i == 1) {
-          Expr one = ones(keywords::shape={1, bShape[1]});
-          a = dot(a, one);
-        }
-        else {
-          UTIL_THROW2("Not implemented");
-        }
-      }
-    }
-    return a;
-  }
-}
+Expr broadcast(Shape bShape, Expr a);
 
 /*********************************************************/
 
@@ -126,7 +45,7 @@ inline Expr sum(Expr a, Args ...args) {
     int rows = n->val().shape()[0];
     return {1, rows};
   };
-  Expr one = ones(shape={1, n->shape()[0]},
+  Expr one = a.graph()->ones(shape={1, n->shape()[0]},
                   lazy_shape=lshape);
   return dot(one, a);
 }
@@ -136,8 +55,8 @@ inline Expr sum(Expr a, Args ...args) {
     //std::cerr << "Shape will be " << cols << " by 1." << std::endl;
     return {cols, 1};
   };
-  Expr one = ones(shape={n->shape()[1], 1},
-                  lazy_shape=lshape);
+  Expr one = a.graph()->ones(shape={n->shape()[1], 1},
+                             lazy_shape=lshape);
   return dot(a, one);
 }
 else if(ax == 2) {
@@ -164,18 +83,13 @@ inline Expr softmax(Expr a, Args ...args) {
     return {1,1};
   };
   using namespace keywords;
-  Expr one = ones(shape={1, 1}, lazy_shape=print_shape);
+  Expr one = a.graph()->ones(shape={1, 1}, lazy_shape=print_shape);
 #endif
 
   return e / sum(e, args...);
 }
 
-template <typename ...Args>
-inline Expr softmax_fast(Expr a, Args ...args) {
-  Expr e = Expr(new SoftmaxNodeOp(a, args...));
-  return e;
-}
-
+//inline Expr softmax_fast(Expr a, kaxis axis);
 
 // inefficient
 template <typename ...Args>
@@ -187,12 +101,12 @@ inline Expr mean(Expr a, Args ...args) {
   ChainPtr n = a.node();
   switch (ax) {
     case 0:
-      return sum(a, axis=0) / constant(shape={1, 1},
+      return sum(a, axis=0) / a.graph()->constant(shape={1, 1},
                                        lazy_value=[n]() -> Float {
                                          return n->val().shape()[0];
                                        });
     case 1:
-      return sum(a, axis=1) / constant(shape={1, 1},
+      return sum(a, axis=1) / a.graph()->constant(shape={1, 1},
                                        lazy_value=[n]() -> Float {
                                          return n->val().shape()[1];
                                        });
@@ -201,7 +115,7 @@ inline Expr mean(Expr a, Args ...args) {
     case 3:
       UTIL_THROW2("Not implemented");
     default:
-      return sum(a) / constant(shape={1, 1},
+      return sum(a) / a.graph()->constant(shape={1, 1},
                                lazy_value=[n]() -> Float {
                                  return n->val().size();
                                });
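The sum() overloads keep reductions inside the graph by the same ones-vector
trick: a {1,rows} ones row on the left collapses axis 0, a {cols,1} ones
column on the right collapses axis 1, and mean() divides by a lazily
evaluated element count. Spelled out for a 2x3 matrix (plain C++ sketch, no
marian types):

    #include <cstdio>

    int main() {
      double a[2][3] = {{1, 2, 3}, {4, 5, 6}};
      double axis0[3] = {0, 0, 0};        // dot(ones{1,2}, a) -> shape {1,3}
      for (int j = 0; j < 3; ++j)
        for (int i = 0; i < 2; ++i)
          axis0[j] += 1.0 * a[i][j];      // ones row picks up every row i
      std::printf("%g %g %g\n", axis0[0], axis0[1], axis0[2]);  // 5 7 9
    }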
diff --git a/src/expressions.cu b/src/expressions.cu
deleted file mode 100644
index b2ff90ba..00000000
--- a/src/expressions.cu
+++ /dev/null
@@ -1,59 +0,0 @@
-#include <sstream>
-#include "expressions.h"
-#include "graph_operators.h"
-
-using namespace std;
-
-namespace marian {
-
-Expr::Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
-Expr::Expr(Float v) : pimpl_(new ConstantNode(keywords::value=v,
-                                              keywords::shape={1,1})) {}
-
-Tensor Expr::val() {
-  return pimpl_->val();
-}
-
-Tensor Expr::grad() {
-  return pimpl_->grad();
-}
-
-ChainPtr Expr::node() {
-  return pimpl_;
-}
-
-void Expr::forward(size_t batchSize) {
-  UTIL_THROW_IF2(pimpl_.get() != Chainable<Tensor>::stack.back(),
-                 "Trying to call forward on non-root of computation graph");
-  for(auto&& v : Chainable<Tensor>::stack) {
-    v->allocate(batchSize);
-  }
-  for(auto&& v : Chainable<Tensor>::stack)
-    v->forward();
-}
-
-void Expr::backward() {
-  UTIL_THROW_IF2(pimpl_.get() != Chainable<Tensor>::stack.back(),
-                 "Trying to call backward on non-root of computation graph");
-  for(auto&& v : Chainable<Tensor>::stack)
-    v->set_zero_adjoint();
-
-  typedef typename Chainable<Tensor>::ChainableStack::reverse_iterator It;
-  pimpl_->init_dependent();
-  for(It it = Chainable<Tensor>::stack.rbegin(); it != Chainable<Tensor>::stack.rend(); ++it)
-    (*it)->backward();
-}
-
-Expr::operator ChainPtr() {
-  return pimpl_;
-}
-
-std::string Expr::Debug() const
-{
-  stringstream strm;
-  const Shape &shape = pimpl_->shape();
-  strm << marian::Debug(shape);
-  return strm.str();
-}
-
-}
diff --git a/src/expressions.h b/src/expressions.h
deleted file mode 100644
index 43016dac..00000000
--- a/src/expressions.h
+++ /dev/null
@@ -1,33 +0,0 @@
-#pragma once
-
-#include "definitions.h"
-#include "graph.h"
-
-namespace marian {
-
-class Expr {
-  public:
-    Expr(Chainable<Tensor>* chainable);
-    Expr(Float v);
-
-    Expr operator=(Tensor t) {
-      pimpl_->setVal(t);
-      return *this;
-    }
-
-    Tensor val();
-    Tensor grad();
-
-    void forward(size_t batchSize);
-    void backward();
-
-    ChainPtr node();
-    operator ChainPtr();
-
-    std::string Debug() const;
-
-  private:
-    ChainPtr pimpl_;
-};
-
-}
diff --git a/src/marian.h b/src/marian.h
index 0876d4cd..5cc06dd7 100644
--- a/src/marian.h
+++ b/src/marian.h
@@ -1,9 +1,7 @@
 #pragma once
 
 #include "definitions.h"
-#include "graph.h"
-#include "graph_operators.h"
-#include "expressions.h"
-#include "expression_operators.h"
+#include "expression_graph.h"
 #include "param_initializers.h"
+#include "expression_operators.h"
diff --git a/src/graph.h b/src/node.h
similarity index 68%
rename from src/graph.h
rename to src/node.h
index 329720b4..29d240cd 100644
--- a/src/graph.h
+++ b/src/node.h
@@ -2,36 +2,10 @@
 
 #include "keywords.h"
 #include "tensor.h"
+#include "chainable.h"
 
 namespace marian {
 
-template <class DataType>
-struct Chainable {
-  Chainable() { }
-  virtual ~Chainable() { }
-  virtual void forward() { }
-  virtual void backward() { }
-  virtual void init_dependent() { }
-  virtual void set_zero_adjoint() { }
-
-  virtual void allocate(size_t) = 0;
-
-  virtual const Shape& shape() = 0;
-  virtual DataType &val() = 0;
-  virtual DataType grad() = 0;
-  virtual void setVal(DataType t) {
-    UTIL_THROW2("Tensors can only be assigned to input nodes");
-  };
-
-  typedef std::vector<Chainable<DataType>*> ChainableStack;
-  static ChainableStack stack;
-};
-
-template <class DataType>
-typename Chainable<DataType>::ChainableStack Chainable<DataType>::stack;
-
-typedef std::shared_ptr<Chainable<Tensor>> ChainPtr;
-
 class Node : public Chainable<Tensor>,
              public keywords::Keywords {
   public:
@@ -40,9 +14,7 @@ class Node : public Chainable<Tensor>,
       : Keywords(args...),
        shape_(Get<Shape>(keywords::shape, {1, 1})),
        name_(Get<std::string>(keywords::name, "none"))
-    {
-      stack.push_back(this);
-    }
+    { }
 
     virtual ~Node() {};
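Retiring the static Chainable::stack (and the self-registration in Node's
constructor) is what makes several graphs per process possible: a node is now
recorded on the stack of the ExpressionGraph it was created through, via the
Expr constructor. Under the API introduced above, something like the
following should hold (a sketch, assuming the headers in this tree compile
stand-alone):

    #include "marian.h"

    using namespace marian;
    using namespace keywords;

    int main() {
      ExpressionGraph g1, g2;             // two independent tapes
      auto a = g1.input(shape={1, 4});    // recorded only on g1's stack
      auto b = g2.input(shape={1, 4});    // recorded only on g2's stack
      g1.forward(1);                      // allocates and runs g1's nodes only
      g2.forward(1);
      return 0;
    }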
diff --git a/src/graph_operators.h b/src/node_operators.h
similarity index 70%
rename from src/graph_operators.h
rename to src/node_operators.h
index c7c0a057..54fe0bcd 100644
--- a/src/graph_operators.h
+++ b/src/node_operators.h
@@ -1,7 +1,6 @@
 #pragma once
 
-#include "expressions.h"
-#include "graph.h"
+#include "node.h"
 #include "tensor_operators.h"
 
 namespace marian {
@@ -108,49 +107,14 @@ struct TanhNodeOp : public UnaryNodeOp {
   }
 };
 
-struct ArgmaxOp : public UnaryNodeOp {
-  template <typename ...Args>
-  ArgmaxOp(ChainPtr a, Args ...args)
-    : UnaryNodeOp(a, keywords::shape=newShape(a, -1), args...),
-      axis_(-1) { }
-
-  Shape newShape(ChainPtr a, int axis) {
-    Shape shape1 = a->shape();
-    UTIL_THROW_IF2(shape1.size() > 2,
-                   "Tensors with more than 2 dimensions not supported yet");
-    if(axis == 0) {
-      shape1[0] = 1;
-    }
-    else if(axis == 1) {
-      shape1[1] = 1;
-    }
-    else {
-      shape1 = {1, 1};
-    }
-    return shape1;
-  }
-
-  void forward() {
-    //val_ = Argmax(a_->val(), axis_);
-    UTIL_THROW2("Not implemented");
-  }
-
-  void backward() {
-    UTIL_THROW2("Not implemented");
-  }
-
-  private:
-    int axis_;
-};
-
 // @TODO, make this numerically safe(r):
 // softmax(X) = softmax_safe(X - max(X, axis=1))
 // Probably best to do this directly in Softmax
 // function.
 struct SoftmaxNodeOp : public UnaryNodeOp {
   template <typename ...Args>
-  SoftmaxNodeOp(ChainPtr a, Args ...args)
-    : UnaryNodeOp(a, args...) { }
+  SoftmaxNodeOp(Args ...args)
+    : UnaryNodeOp(args...) { }
 
   void forward() {
     // B = softmax(A).
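On the @TODO above: subtracting the row maximum before exponentiating leaves
the result unchanged, since exp(x - m) / sum exp(x - m) = exp(x) / sum
exp(x), but keeps exp() out of overflow range. A scalar, single-row sketch of
the safe variant (illustration only, not the CUDA kernel):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
      double x[3] = {1000, 1001, 1002};        // naive exp(1000) overflows
      double m = *std::max_element(x, x + 3);  // max(X, axis=1) for this row
      double z = 0;
      for (double v : x) z += std::exp(v - m); // shifted partition sum
      for (double v : x)
        std::printf("%.4f ", std::exp(v - m) / z);  // 0.0900 0.2447 0.6652
      std::printf("\n");
    }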
@@ -172,8 +136,8 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
 
 struct LogNodeOp : public UnaryNodeOp {
   template <typename ...Args>
-  LogNodeOp(ChainPtr a, Args ...args)
-    : UnaryNodeOp(a, args...) {}
+  LogNodeOp(Args ...args)
+    : UnaryNodeOp(args...) {}
 
   void forward() {
     Element(_1 = Log(_2), val_, a_->val());
@@ -187,8 +151,8 @@ struct LogNodeOp : public UnaryNodeOp {
 
 struct ExpNodeOp : public UnaryNodeOp {
   template <typename ...Args>
-  ExpNodeOp(ChainPtr a, Args ...args)
-    : UnaryNodeOp(a, args...) { }
+  ExpNodeOp(Args ...args)
+    : UnaryNodeOp(args...) { }
 
   void forward() {
     Element(_1 = Exp(_2), val_, a_->val());
@@ -230,8 +194,9 @@ struct BinaryNodeOp : public Node {
 struct DotNodeOp : public BinaryNodeOp {
   template <typename ...Args>
   DotNodeOp(ChainPtr a, ChainPtr b, Args ...args)
-    : BinaryNodeOp(a, b,
-                   keywords::shape=newShape(a,b),
+    : BinaryNodeOp(
+        a, b,
+        keywords::shape=newShape(a, b),
                    args...) { }
 
   Shape newShape(ChainPtr a, ChainPtr b) {
@@ -259,41 +224,40 @@ struct DotNodeOp : public BinaryNodeOp {
   }
 };
 
-Expr broadcast(Shape shape, Expr a);
+//struct BroadcastingNodeOp : public BinaryNodeOp {
+//  template <typename ...Args>
+//  BroadcastingNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+//   : BinaryNodeOp(broadcast(newShape(a ,b), a),
+//                  broadcast(newShape(a ,b), b),
+//                  keywords::shape=newShape(a, b),
+//                  args...) {}
+//
+//  static Shape newShape(ChainPtr a, ChainPtr b) {
+//    size_t dimsA = a->shape().size();
+//    size_t dimsB = b->shape().size();
+//    UTIL_THROW_IF2(dimsA != dimsB,
+//                   "Tensors have different numbers of dimensions");
+//    Shape shape(dimsA);
+//    for(size_t i = 0; i < dimsA; ++i) {
+//      int dimA = a->shape()[i];
+//      int dimB = b->shape()[i];
+//      bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1);
+//      UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
+//                     << "operation cannot be broadcasted: " << dimA << " != " << dimB);
+//      shape[i] = std::max(dimA, dimB);
+//      if(dimA == whatevs || dimB == whatevs)
+//        shape[i] = whatevs;
+//    }
+//    return shape;
+//  }
+//};
 
-struct BroadcastingNodeOp : public BinaryNodeOp {
-  template <typename ...Args>
-  BroadcastingNodeOp(Expr a, Expr b, Args ...args)
-  : BinaryNodeOp(broadcast(newShape(a ,b), a),
-                 broadcast(newShape(a ,b), b),
-                 keywords::shape=newShape(a, b),
-                 args...) {}
-
-  static Shape newShape(ChainPtr a, ChainPtr b) {
-    size_t dimsA = a->shape().size();
-    size_t dimsB = b->shape().size();
-    UTIL_THROW_IF2(dimsA != dimsB,
-                   "Tensors have different numbers of dimensions");
-    Shape shape(dimsA);
-    for(size_t i = 0; i < dimsA; ++i) {
-      int dimA = a->shape()[i];
-      int dimB = b->shape()[i];
-      bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1);
-      UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
-                     << "operation cannot be broadcasted: " << dimA << " != " << dimB);
-      shape[i] = std::max(dimA, dimB);
-      if(dimA == whatevs || dimB == whatevs)
-        shape[i] = whatevs;
-    }
-    return shape;
-  }
-};
-
-struct PlusNodeOp : public BroadcastingNodeOp {
-  template <typename ...Args>
-  PlusNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
-
+struct PlusNodeOp : public BinaryNodeOp {
+  template <typename ...Args>
+  PlusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
+
   void forward() {
     Element(_1 = _2 + _3,
             val_, a_->val(), b_->val());
@@ -307,10 +271,11 @@ struct PlusNodeOp : public BroadcastingNodeOp {
   }
 };
 
-struct MinusNodeOp : public BroadcastingNodeOp {
+struct MinusNodeOp : public BinaryNodeOp {
   template <typename ...Args>
-  MinusNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
-
+  MinusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
+
   void forward() {
     Element(_1 = _2 - _3,
             val_, a_->val(), b_->val());
@@ -324,10 +289,11 @@ struct MinusNodeOp : public BroadcastingNodeOp {
   }
 };
 
-struct MultNodeOp : public BroadcastingNodeOp {
+struct MultNodeOp : public BinaryNodeOp {
   template <typename ...Args>
-  MultNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
-
+  MultNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
+
   void forward() {
     Element(_1 = _2 * _3,
             val_, a_->val(), b_->val());
@@ -341,9 +307,10 @@ struct MultNodeOp : public BroadcastingNodeOp {
   }
 };
 
-struct DivNodeOp : public BroadcastingNodeOp {
+struct DivNodeOp : public BinaryNodeOp {
   template <typename ...Args>
-  DivNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
+  DivNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
 
   void forward() {
     Element(_1 = _2 / _3,
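With BroadcastingNodeOp retired (kept above only as a comment), the
elementwise nodes take keywords::shape=a->shape() and rely on the operators
in expression_operators.cu to pre-broadcast both arguments to a common shape.
Their backward passes (elided between the hunks above) are expected to apply
the usual scalar adjoint rules elementwise via Element(...); written out for
scalars, with a tiny check for the product rule (a sketch of the math, not
code from this patch):

    #include <cstdio>

    // c = a + b :  da += dc;        db += dc;
    // c = a - b :  da += dc;        db -= dc;
    // c = a * b :  da += dc * b;    db += dc * a;
    // c = a / b :  da += dc / b;    db -= dc * a / (b * b);
    int main() {
      double a = 2, b = 5, dc = 1;           // upstream adjoint of c = a * b
      double da = dc * b, db = dc * a;       // product rule
      std::printf("da=%g db=%g\n", da, db);  // da=5 db=2
      return 0;
    }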
diff --git a/src/param_initializers.h b/src/param_initializers.h
index 04c6b48e..da8d3578 100644
--- a/src/param_initializers.h
+++ b/src/param_initializers.h
@@ -18,7 +18,7 @@ void ones(Tensor t) {
 }
 
 template <class Distribution>
-void distribution(Tensor t, float a=0.0, float b=0.1) {
+void distribution(Tensor t, float a, float b) {
   std::random_device device;
   std::default_random_engine engine(device());
   Distribution dist(a, b);
diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu
index 7d812e36..342ef26e 100644
--- a/src/validate_mnist.cu
+++ b/src/validate_mnist.cu
@@ -2,7 +2,6 @@
 #include "marian.h"
 #include "mnist.h"
 #include "npz_converter.h"
-#include "param_initializers.h"
 
 using namespace marian;
 using namespace keywords;
@@ -31,13 +30,14 @@ int main(int argc, char** argv) {
 
   std::cerr << "Building model...";
 
-  auto x = input(shape={whatevs, IMAGE_SIZE});
-  auto y = input(shape={whatevs, LABEL_SIZE});
+  ExpressionGraph g;
+  auto x = g.input(shape={whatevs, IMAGE_SIZE}, name="X");
+  auto y = g.input(shape={whatevs, LABEL_SIZE});
 
-  auto w = param(shape={IMAGE_SIZE, LABEL_SIZE},
-                 init=from_vector(wData));
-  auto b = param(shape={1, LABEL_SIZE},
-                 init=from_vector(bData));
+  auto w = g.param(shape={IMAGE_SIZE, LABEL_SIZE},
+                   init=from_vector(wData));
+  auto b = g.param(shape={1, LABEL_SIZE},
+                   init=from_vector(bData));
 
   auto probs = softmax(dot(x, w) + b, axis=1);
   auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
@@ -50,7 +50,7 @@ int main(int argc, char** argv) {
   x = xt << testImages;
   y = yt << testLabels;
 
-  cost.forward(BATCH_SIZE);
+  g.forward(BATCH_SIZE);
 
   std::vector<float> results;
   results << probs.val();
@@ -66,17 +66,17 @@ int main(int argc, char** argv) {
     acc += (correct == proposed);
   }
   std::cerr << "Cost: " << cost.val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
-
+
   float eta = 0.1;
   for (size_t j = 0; j < 10; ++j) {
     for(size_t i = 0; i < 60; ++i) {
-      cost.backward();
+      g.backward();
 
       auto update_rule = _1 -= eta * _2;
       Element(update_rule, w.val(), w.grad());
       Element(update_rule, b.val(), b.grad());
 
-      cost.forward(BATCH_SIZE);
+      g.forward(BATCH_SIZE);
     }
     std::cerr << "Epoch: " << j << std::endl;
     std::vector<float> results;