diff --git a/src/definitions.h b/src/definitions.h
new file mode 100644
index 00000000..5e3fb64d
--- /dev/null
+++ b/src/definitions.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include
+#include
+#include
+
+namespace marian {
+typedef float Float;
+}
+
+#include "keywords.h"
+#include "tensor.h"
+
+namespace marian {
+
+typedef std::vector<int> Shape;
+const int whatevs{-1};
+
+namespace keywords {
+  KEY(init, std::function<void(Tensor)>)
+  KEY(axis, int)
+  KEY(name, std::string)
+  KEY(shape, Shape)
+  KEY(value, float)
+  KEY(lazy_shape, std::function<Shape()>)
+  KEY(lazy_value, std::function<Float()>)
+}
+
+}
diff --git a/src/expression_operators.h b/src/expression_operators.h
new file mode 100644
index 00000000..2d2ac18a
--- /dev/null
+++ b/src/expression_operators.h
@@ -0,0 +1,185 @@
+#pragma once
+
+#include "graph.h"
+#include "graph_operators.h"
+#include "expressions.h"
+
+namespace marian {
+
+template <typename ...Args>
+inline Expr data(Args ...args) {
+  return Expr(new DataNode(args...));
+}
+
+template <typename ...Args>
+inline Expr param(Args ...args) {
+  return Expr(new ParamNode(args...));
+}
+template <typename ...Args>
+inline Expr constant(Args ...args) {
+  return Expr(new ConstantNode(args...));
+}
+
+template <typename ...Args>
+inline Expr ones(Args ...args) {
+  return Expr(new ConstantNode(keywords::value=1, args...));
+}
+
+template <typename ...Args>
+inline Expr zeroes(Args ...args) {
+  return Expr(new ConstantNode(keywords::value=0, args...));
+}
+
+/*********************************************************/
+
+inline Expr sigmoid(Expr a) {
+  return Expr(new SigmoidNodeOp(a));
+}
+
+inline Expr tanh(Expr a) {
+  return Expr(new TanhNodeOp(a));
+}
+
+inline Expr log(Expr a) {
+  return Expr(new LogNodeOp(a));
+};
+
+inline Expr exp(Expr a) {
+  return Expr(new ExpNodeOp(a));
+};
+
+inline Expr operator-(Expr a) {
+  return Expr(new NegNodeOp(a));
+};
+
+/*********************************************************/
+
+inline Expr operator+(Expr a, Expr b) {
+  return Expr(new PlusNodeOp(a, b));
+}
+
+inline Expr operator-(Expr a, Expr b) {
+  return Expr(new MinusNodeOp(a, b));
+}
+
+inline Expr operator*(Expr a, Expr b) {
+  return Expr(new MultNodeOp(a, b));
+}
+
+inline Expr operator/(Expr a, Expr b) {
+  return Expr(new DivNodeOp(a, b));
+}
+
+inline Expr dot(Expr a, Expr b) {
+  return Expr(new DotNodeOp(a, b));
+}
+
+/******************************************************/
+
+Expr broadcast(Shape shape, Expr a) {
+  if(a.val().shape() == shape) {
+    return a;
+  }
+  else {
+    size_t dimsA = a.val().shape().size();
+    size_t dimsB = shape.size();
+    UTIL_THROW_IF2(dimsA != dimsB,
+                   "Tensor and shape have different number of dimensions");
+    for(size_t i = 0; i < dimsA; ++i) {
+      int dimA = a.val().shape()[i];
+      int dimB = shape[i];
+      bool broadcastable = (dimA == dimB || dimA == 1);
+      UTIL_THROW_IF2(!broadcastable,
+                     "Cannot broadcast tensor dimension "
+                     << dimA << " to " << dimB);
+      if(dimA == 1 && dimB > 1) {
+        std::cerr << "Broadcasting dim " << i << " from " << dimA << " to " << dimB << std::endl;
+        if(i == 0) {
+          Expr one = ones(keywords::shape={shape[0], 1});
+          a = dot(one, a);
+        }
+        else if(i == 1) {
+          Expr one = ones(keywords::shape={1, shape[1]});
+          a = dot(a, one);
+        }
+        else {
+          UTIL_THROW2("Not implemented");
+        }
+      }
+    }
+    return a;
+  }
+}
+
+/*********************************************************/
+
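For reference, broadcast() above grows a singleton dimension by a matrix product with a ones vector rather than by copying data. A minimal sketch of the same idea with hypothetical shapes, using only functions introduced in this patch (illustrative, not part of the patch itself):

    // Expand a 1 x 3 bias row to 4 x 3 so it can be combined with a 4 x 3 matrix:
    // ones(4,1) * b(1,3) yields a 4 x 3 expression whose rows all equal b,
    // which is what broadcast() builds when dimension 0 is 1.
    Expr b = param(keywords::shape={1, 3}, keywords::name="bias");
    Expr B = dot(ones(keywords::shape={4, 1}), b);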
+// inefficient
+template <typename ...Args>
+inline Expr sum(Expr a, Args ...args) {
+  using namespace keywords;
+  Keywords params(args...);
+  int ax = params.Get(axis, whatevs);
+
+  if(ax == 0) {
+    auto lshape = [&a]() -> Shape {
+      int rows = a.val().shape()[0];
+      return {1, rows};
+    };
+    Expr one = ones(lazy_shape=lshape);
+    return dot(one, a);
+  }
+  else if(ax == 1) {
+    auto lshape = [&a]() -> Shape {
+      int cols = a.val().shape()[1];
+      return {cols, 1};
+    };
+    Expr one = ones(lazy_shape=lshape);
+    return dot(a, one);
+  }
+  else if(ax == 2) {
+    UTIL_THROW2("Not implemented");
+  }
+  else if(ax == 3) {
+    UTIL_THROW2("Not implemented");
+  }
+  return sum(sum(a, axis=0), axis=1);
+}
+
+// inefficient
+template <typename ...Args>
+inline Expr softmax(Expr a, Args ...args) {
+  Expr e = exp(a);
+  return e / sum(e, args...);
+}
+
+// inefficient
+template <typename ...Args>
+inline Expr mean(Expr a, Args ...args) {
+  using namespace keywords;
+  Keywords params(args...);
+  size_t ax = params.Get(axis, whatevs);
+
+  switch (ax) {
+    case 0:
+      return sum(a, axis=0) / constant(shape={1, 1},
+                                       lazy_value=[&a]() -> Float {
+                                         return a.val().shape()[0];
+                                       });
+    case 1:
+      return sum(a, axis=1) / constant(shape={1, 1},
+                                       lazy_value=[&a]() -> Float {
+                                         return a.val().shape()[1];
+                                       });
+    case 2:
+      UTIL_THROW2("Not implemented");
+    case 3:
+      UTIL_THROW2("Not implemented");
+    default:
+      return sum(a) / constant(shape={1, 1},
+                               lazy_value=[&a]() -> Float {
+                                 return a.val().size();
+                               });
+  }
+}
+
+}
\ No newline at end of file
diff --git a/src/expressions.h b/src/expressions.h
new file mode 100644
index 00000000..b78acf79
--- /dev/null
+++ b/src/expressions.h
@@ -0,0 +1,55 @@
+#pragma once
+
+namespace marian {
+
+class Expr {
+  public:
+    Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
+
+    Tensor val() {
+      return pimpl_->val();
+    }
+
+    Tensor grad() {
+      return pimpl_->grad();
+    }
+
+    ChainPtr pimpl() {
+      return pimpl_;
+    }
+
+    void forward() {
+      UTIL_THROW_IF2(pimpl_.get() != stack.back(),
+                     "Trying to call forward on non-root of computation graph");
+
+      std::cerr << "a" << std::endl;
+      for(auto&& v : stack)
+        v->allocate();
+
+      std::cerr << "f" << std::endl;
+      for(auto&& v : stack)
+        v->forward();
+    }
+
+    void backward() {
+      UTIL_THROW_IF2(pimpl_.get() != stack.back(),
+                     "Trying to call backward on non-root of computation graph");
+
+      for(auto&& v : stack)
+        v->set_zero_adjoint();
+
+      typedef ChainableStack::reverse_iterator It;
+      pimpl_->init_dependent();
+      for(It it = stack.rbegin(); it != stack.rend(); ++it)
+        (*it)->backward();
+    }
+
+    operator ChainPtr() {
+      return pimpl_;
+    }
+
+  private:
+    ChainPtr pimpl_;
+};
+
+}
\ No newline at end of file
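Taken together with expression_operators.h, Expr::forward() walks the global tape in creation order and Expr::backward() seeds the root's adjoint and walks it in reverse. A rough end-to-end sketch of how the pieces compose, with hypothetical shapes and a lambda initializer (illustrative only, not part of the patch):

    using namespace marian;
    using namespace keywords;

    Expr x = data(shape={4, 2}, name="x");
    Expr w = param(shape={2, 1}, name="w",
                   init=[](Tensor t) { uniform(t); });  // uniform() is defined in tensor.h
    Expr y = sigmoid(dot(x, w));
    Expr cost = -mean(log(y));   // last node created, hence the root of the tape

    cost.forward();              // allocates all node tensors, then evaluates them in order
    cost.backward();             // init_dependent() on the root, backward() in reverse order
    Tensor gw = w.grad();        // accumulated d(cost)/dw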
diff --git a/src/graph.h b/src/graph.h
new file mode 100644
index 00000000..47313501
--- /dev/null
+++ b/src/graph.h
@@ -0,0 +1,83 @@
+#pragma once
+
+#include "keywords.h"
+#include "tensor.h"
+
+namespace marian {
+
+template <class DataType>
+struct Chainable {
+  Chainable() { }
+  virtual ~Chainable() { }
+  virtual void forward() { }
+  virtual void backward() { }
+  virtual void init_dependent() { }
+  virtual void set_zero_adjoint() { }
+
+  virtual void allocate() = 0;
+
+  virtual DataType val() = 0;
+  virtual DataType grad() = 0;
+};
+
+typedef std::vector<Chainable<Tensor>*> ChainableStack;
+typedef std::shared_ptr<Chainable<Tensor>> ChainPtr;
+
+ChainableStack stack;
+
+class Node : public Chainable<Tensor>,
+             public keywords::Keywords {
+  public:
+    template <typename ...Args>
+    Node(Args ...args)
+     : Keywords(args...),
+       shape_(Get(keywords::shape, {1, 1})),
+       name_(Get(keywords::name, "none"))
+    {
+      std::cerr << "Creating node " << name_ << std::endl;
+      stack.push_back(this);
+    }
+
+    virtual ~Node() {};
+
+    virtual void allocate() {
+      val_.allocate(shape_);
+    }
+
+    virtual void init_dependent() {
+      if(adj_) {
+        adj_.set(1);
+      }
+      else {
+        adj_.allocate(shape_, 1);
+      }
+    }
+
+    virtual void set_zero_adjoint() {
+      if(adj_) {
+        adj_.set(0);
+      }
+      else {
+        adj_.allocate(shape_, 0);
+      }
+    }
+
+    virtual Tensor val() {
+      UTIL_THROW_IF2(!val_, "Tensor has not been allocated");
+      return val_;
+    };
+
+    virtual Tensor grad() {
+      UTIL_THROW_IF2(!adj_, "Tensor has not been allocated");
+      return adj_;
+    };
+
+  protected:
+    Shape shape_;
+    std::string name_;
+
+    Tensor val_;
+    Tensor adj_;
+};
+
+}
\ No newline at end of file
diff --git a/src/graph_operators.h b/src/graph_operators.h
new file mode 100644
index 00000000..e22dcf4f
--- /dev/null
+++ b/src/graph_operators.h
@@ -0,0 +1,271 @@
+#pragma once
+
+#include "graph.h"
+#include "expressions.h"
+//#include "expression_operators.h"
+
+namespace marian {
+
+struct DataNode : public Node {
+  template <typename ...Args>
+  DataNode(Args ...args)
+  : Node(args...) { }
+
+  void forward() {}
+  void backward() {}
+};
+
+struct ConstantNode : public Node {
+  template <typename ...Args>
+  ConstantNode(Args ...args)
+  : Node(args...) { }
+
+  void forward() {}
+  void backward() {}
+};
+
+struct ParamNode : public Node {
+  template <typename ...Args>
+  ParamNode(Args ...args)
+  : Node(args...),
+    init_(Get<std::function<void(Tensor)>>(keywords::init, [](Tensor){ }))
+  { }
+
+  void forward() {}
+  void backward() {}
+
+  virtual void allocate() {
+    val_.allocate(shape_);
+    init_(val_);
+  }
+
+  private:
+    std::function<void(Tensor)> init_;
+};
+
+struct UnaryNodeOp : public Node {
+  ChainPtr a_;
+
+  template <typename ...Args>
+  UnaryNodeOp(ChainPtr a, Args ...args)
+  : Node(args...), a_(a) {}
+};
+
+struct SigmoidNodeOp : public UnaryNodeOp {
+  template <typename ...Args>
+  SigmoidNodeOp(Args ...args)
+  : UnaryNodeOp(args...) { }
+
+  void forward() {
+    Element(_1 = Sigma(_2),
+            val_, a_->val());
+  }
+
+  void backward() {
+    Element(_1 += _2 * Sigma(_3) * (1 - Sigma(_3)),
+            a_->grad(), adj_, a_->val());
+  }
+};
+
+struct TanhNodeOp : public UnaryNodeOp {
+  template <typename ...Args>
+  TanhNodeOp(Args ...args)
+  : UnaryNodeOp(args...) { }
+
+  void forward() {
+    Element(_1 = Tanh(_2),
+            val_, a_->val());
+  }
+
+  void backward() {
+    Element(_1 += _2 * (1 - Tanh(_3) * Tanh(_3)),
+            a_->grad(), adj_, a_->val());
+  }
+};
+
+struct LogNodeOp : public UnaryNodeOp {
+  template <typename ...Args>
+  LogNodeOp(Args ...args)
+  : UnaryNodeOp(args...) {
+    std::cerr << "log" << std::endl;
+  }
+
+  void forward() {
+    Element(_1 = Log(_2), val_, a_->val());
+  }
+
+  void backward() {
+    Element(_1 += _2 * 1.f / _3,
+            a_->grad(), adj_, a_->val());
+  }
+};
+
+struct ExpNodeOp : public UnaryNodeOp {
+  template <typename ...Args>
+  ExpNodeOp(Args ...args)
+  : UnaryNodeOp(args...) { }
+
+  void forward() {
+    Element(_1 = Exp(_2), val_, a_->val());
+  }
+
+  void backward() {
+    Element(_1 += _2 * Exp(_3),
+            a_->grad(), adj_, a_->val());
+  }
+};
+
+struct NegNodeOp : public UnaryNodeOp {
+  template <typename ...Args>
+  NegNodeOp(Args ...args)
+  : UnaryNodeOp(args...) { }
+
+  void forward() {
+    Element(_1 = -_2, val_, a_->val());
+  }
+
+  void backward() {
+    Element(_1 += -_2, a_->grad(), adj_);
+  }
+};
+
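Each unary node above follows the same pattern: forward() writes the function of a_->val() into val_, and backward() accumulates adjoint times local derivative into a_->grad(). Purely to illustrate that pattern, a hypothetical ReLU node (not part of this patch; it only assumes the thrust placeholders already used above) could be added the same way:

    struct ReluNodeOp : public UnaryNodeOp {
      template <typename ...Args>
      ReluNodeOp(Args ...args)
      : UnaryNodeOp(args...) { }

      void forward() {
        // val = a where a > 0, else 0
        Element(_1 = _2 * (_2 > 0.0f), val_, a_->val());
      }

      void backward() {
        // da += adj * 1{a > 0}
        Element(_1 += _2 * (_3 > 0.0f), a_->grad(), adj_, a_->val());
      }
    };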
+/******************************************************/
+
+struct BinaryNodeOp : public Node {
+  ChainPtr a_;
+  ChainPtr b_;
+
+  template <typename ...Args>
+  BinaryNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+   : Node(args...), a_(a), b_(b) {}
+};
+
+/*** Matrix Product ***/
+
+struct DotNodeOp : public BinaryNodeOp {
+  template <typename ...Args>
+  DotNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+  : BinaryNodeOp(a, b, args...) { }
+
+  Shape shape(ChainPtr a, ChainPtr b) {
+    UTIL_THROW_IF2(a->val().shape()[1] != b->val().shape()[0],
+                   "matrix product requires dimensions to match");
+    Shape shape1 = a->val().shape();
+    Shape shape2 = b->val().shape();
+    shape1[1] = shape2[1];
+    return shape1;
+  }
+
+  void forward() {
+    // C = A*B
+    Prod(val_, a_->val(), b_->val(), false, false);
+  }
+
+  void backward() {
+    // D is the adjoint, the matrix of derivatives
+    // df/dA += D*B.T
+    // df/dB += A.T*D
+    // beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
+    // to sum gradients from different graph parts
+    Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
+    Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
+  }
+};
+
+Expr broadcast(Shape shape, Expr a);
+
+struct BroadcastingNodeOp : public BinaryNodeOp {
+  template <typename ...Args>
+  BroadcastingNodeOp(Expr a, Expr b, Args ...args)
+    : BinaryNodeOp(broadcast(shape(a ,b), a),
+                   broadcast(shape(a ,b), b),
+                   args...) {}
+
+  static Shape shape(ChainPtr a, ChainPtr b) {
+    size_t dimsA = a->val().shape().size();
+    size_t dimsB = b->val().shape().size();
+    UTIL_THROW_IF2(dimsA != dimsB,
+                   "Tensors have different numbers of dimensions");
+    Shape shape(dimsA);
+    for(size_t i = 0; i < dimsA; ++i) {
+      int dimA = a->val().shape()[i];
+      int dimB = b->val().shape()[i];
+      bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1);
+      UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
+                     << "operation cannot be broadcasted: " << dimA << " != " << dimB);
+      shape[i] = std::max(dimA, dimB);
+    }
+    return shape;
+  }
+};
+
+
+struct PlusNodeOp : public BroadcastingNodeOp {
+  template <typename ...Args>
+  PlusNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
+
+  void forward() {
+    Element(_1 = _2 + _3,
+            val_, a_->val(), b_->val());
+  }
+
+  void backward() {
+    Element(_1 += _2,
+            a_->grad(), adj_);
+    Element(_1 += _2,
+            b_->grad(), adj_);
+  }
+};
+
+struct MinusNodeOp : public BroadcastingNodeOp {
+  template <typename ...Args>
+  MinusNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
+
+  void forward() {
+    Element(_1 = _2 - _3,
+            val_, a_->val(), b_->val());
+  }
+
+  void backward() {
+    Element(_1 += _2,
+            a_->grad(), adj_);
+    Element(_1 -= _2,
+            b_->grad(), adj_);
+  }
+};
+
+struct MultNodeOp : public BroadcastingNodeOp {
+  template <typename ...Args>
+  MultNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
+
+  void forward() {
+    Element(_1 = _2 * _3,
+            val_, a_->val(), b_->val());
+  }
+
+  void backward() {
+    Element(_1 += _2 * _3,
+            a_->grad(), adj_, b_->val());
+    Element(_1 += _2 * _3,
+            b_->grad(), adj_, a_->val());
+  }
+};
+
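For reference, the backward() bodies above (and DivNodeOp just below) implement the standard adjoint rules, writing D for adj_, the gradient of the objective with respect to this node's output:

    // DotNodeOp   (C = A B)   : dA += D B^T      dB += A^T D
    // PlusNodeOp  (c = a + b) : da += D          db += D
    // MinusNodeOp (c = a - b) : da += D          db -= D
    // MultNodeOp  (c = a * b) : da += D * b      db += D * a          (elementwise)
    // DivNodeOp   (c = a / b) : da += D / b      db -= D * a / (b*b)  (elementwise)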
+struct DivNodeOp : public BroadcastingNodeOp {
+  template <typename ...Args>
+  DivNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
+
+  void forward() {
+    Element(_1 = _2 / _3,
+            val_, a_->val(), b_->val());
+  }
+
+  void backward() {
+    Element(_1 += _2 * 1.0f / _3,
+            a_->grad(), adj_, b_->val());
+    Element(_1 -= _2 * _3 / (_4 * _4),
+            b_->grad(), adj_, a_->val(), b_->val());
+  }
+};
+
+}
\ No newline at end of file
diff --git a/src/operators.h b/src/operators.h
deleted file mode 100644
index ef756dd2..00000000
--- a/src/operators.h
+++ /dev/null
@@ -1,361 +0,0 @@
-#pragma once
-
-#include
-#include
-#include
-#include
-
-#include "marian.h"
-#include "cudnn_tensor.h"
-
-namespace marian {
-
-/*** Unary operators ***/
-
-struct UnaryNodeOp : public Node {
-  ChainPtr a_;
-
-  UnaryNodeOp(const Tensor t, ChainPtr a)
-   : Node(t), a_(a) {}
-};
-
-struct SigmaNodeOp : public UnaryNodeOp {
-  SigmaNodeOp(ChainPtr a)
-   : UnaryNodeOp(Tensor(a->val().shape()), a) { }
-
-  void forward() {
-    Element(_1 = Sigma(_2),
-            val_, a_->val());
-  }
-
-  void backward() {
-    Element(_1 += _2 * Sigma(_3) * (1 - Sigma(_3)),
-            a_->grad(), adj_, a_->val());
-  }
-};
-
-inline Var sigma(Var a) {
-  return Var(new SigmaNodeOp(a));
-}
-
-struct TanhNodeOp : public UnaryNodeOp {
-  TanhNodeOp(ChainPtr a)
-   : UnaryNodeOp(Tensor(a->val().shape()), a) { }
-
-  void forward() {
-    Element(_1 = Tanh(_2),
-            val_, a_->val());
-  }
-
-  void backward() {
-    Element(_1 += _2 * (1 - Tanh(_3) * Tanh(_3)),
-            a_->grad(), adj_, a_->val());
-  }
-};
-
-inline Var tanh(Var a) {
-  return Var(new TanhNodeOp(a));
-}
-
-struct LogNodeOp : public UnaryNodeOp {
-  LogNodeOp(ChainPtr a)
-   : UnaryNodeOp(Tensor(a->val().shape()), a) { }
-
-  void forward() {
-    Element(_1 = Log(_2), val_, a_->val());
-  }
-
-  void backward() {
-    Element(_1 += _2 * 1.f / _3,
-            a_->grad(), adj_, a_->val());
-  }
-};
-
-inline Var log(Var a) {
-  return Var(new LogNodeOp(a));
-};
-
-struct ExpNodeOp : public UnaryNodeOp {
-  ExpNodeOp(ChainPtr a)
-   : UnaryNodeOp(Tensor(a->val().shape()), a) { }
-
-  void forward() {
-    Element(_1 = Exp(_2), val_, a_->val());
-  }
-
-  void backward() {
-    Element(_1 += _2 * Exp(_3),
-            a_->grad(), adj_, a_->val());
-  }
-};
-
-inline Var exp(Var a) {
-  return Var(new ExpNodeOp(a));
-};
-
-struct NegNodeOp : public UnaryNodeOp {
-  NegNodeOp(ChainPtr a)
-   : UnaryNodeOp(Tensor(a->val().shape()), a) { }
-
-  void forward() {
-    Element(_1 = -_2, val_, a_->val());
-  }
-
-  void backward() {
-    Element(_1 += -_2, a_->grad(), adj_);
-  }
-};
-
-inline Var operator-(Var a) {
-  return Var(new NegNodeOp(a));
-};
-
-/******************************************************/
-
-struct BinaryNodeOp : public Node {
-  ChainPtr a_;
-  ChainPtr b_;
-
-  BinaryNodeOp(const Tensor t, ChainPtr a, ChainPtr b)
-   : Node(t), a_(a), b_(b) {}
-};
-
-/*** Matrix Product ***/
-
-struct DotNodeOp : public BinaryNodeOp {
-  DotNodeOp(ChainPtr a, ChainPtr b) : BinaryNodeOp(Tensor(shape(a, b)), a, b) { }
-
-  Shape shape(ChainPtr a, ChainPtr b) {
-    UTIL_THROW_IF2(a->val().shape()[1] != b->val().shape()[0],
-                   "matrix product requires dimensions to match");
-    Shape shape1 = a->val().shape();
-    Shape shape2 = b->val().shape();
-    shape1[1] = shape2[1];
-    return shape1;
-  }
-
-  void forward() {
-    // C = A*B
-    Prod(val_, a_->val(), b_->val(), false, false);
-  }
-
-  void backward() {
-    // D is the adjoint, the matrix of derivatives
-    // df/dA += D*B.T
-    // df/dB += A.T*D
-    // beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
-    // to sum gradients from different graph parts
-    Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
-    Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
-  }
-};
-
-inline Var dot(Var a, Var b) {
-  return Var(new DotNodeOp(a, b));
-}
-
-/******************************************************/
-
-Var broadcast(Shape shape, Var a) {
-  if(a.val().shape() == shape) {
-    return a;
-  }
-  else {
-    size_t dimsA = a.val().shape().size();
-    size_t dimsB = shape.size();
-    UTIL_THROW_IF2(dimsA != dimsB,
-                   "Tensor and shape have different number of dimensions");
-    for(size_t i = 0; i < dimsA; ++i) {
-      int dimA = a.val().shape()[i];
-      int dimB = shape[i];
-      bool broadcastable = (dimA == dimB || dimA == 1);
-      UTIL_THROW_IF2(!broadcastable,
-                     "Cannot broadcast tensor dimension "
-                     << dimA << " to " << dimB);
-      if(dimA == 1 && dimB > 1) {
-        std::cerr << "Broadcasting dim " << i << " from " << dimA << " to " << dimB << std::endl;
-        if(i == 0) {
-          Var one = Tensor({shape[0], 1}, 1);
-          a = dot(one, a);
-        }
-        else if(i == 1) {
-          Var one = Tensor({1, shape[1]}, 1);
-          a = dot(a, one);
-        }
-        else {
-          UTIL_THROW2("Not implemented");
-        }
-      }
-    }
-    return a;
-  }
-}
-
-struct BroadcastingNodeOp : public BinaryNodeOp {
-  BroadcastingNodeOp(Var a, Var b)
-  : BroadcastingNodeOp(Tensor(shape(a ,b)), broadcast(shape(a ,b), a), broadcast(shape(a ,b), b)) {}
-
-  static Shape shape(ChainPtr a, ChainPtr b) {
-    size_t dimsA = a->val().shape().size();
-    size_t dimsB = b->val().shape().size();
-    UTIL_THROW_IF2(dimsA != dimsB,
-                   "Tensors have different numbers of dimensions");
-    Shape shape(dimsA);
-    for(size_t i = 0; i < dimsA; ++i) {
-      int dimA = a->val().shape()[i];
-      int dimB = b->val().shape()[i];
-      bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1);
-      UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
-                     << "operation cannot be broadcasted: " << dimA << " != " << dimB);
-      shape[i] = std::max(dimA, dimB);
-    }
-    return shape;
-  }
-
-  private:
-    BroadcastingNodeOp(const Tensor t, ChainPtr a, ChainPtr b)
-    : BinaryNodeOp(t, a, b) {}
-};
-
-/*** Binary arithmetic ***/
-
-/*** Plus ***/
-
-struct PlusNodeOp : public BroadcastingNodeOp {
-  PlusNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
-
-  void forward() {
-    Element(_1 = _2 + _3,
-            val_, a_->val(), b_->val());
-  }
-
-  void backward() {
-    Element(_1 += _2,
-            a_->grad(), adj_);
-    Element(_1 += _2,
-            b_->grad(), adj_);
-  }
-};
-
-inline Var operator+(Var a, Var b) {
-  return Var(new PlusNodeOp(a, b));
-}
-
-/*** Minus ***/
-
-struct MinusNodeOp : public BroadcastingNodeOp {
-  MinusNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
-
-  void forward() {
-    Element(_1 = _2 - _3,
-            val_, a_->val(), b_->val());
-  }
-
-  void backward() {
-    Element(_1 += _2,
-            a_->grad(), adj_);
-    Element(_1 -= _2,
-            b_->grad(), adj_);
-  }
-};
-
-inline Var operator-(Var a, Var b) {
-  return Var(new MinusNodeOp(a, b));
-}
-
-/*** Mult ***/
-
-struct MultNodeOp : public BroadcastingNodeOp {
-  MultNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
-
-  void forward() {
-    Element(_1 = _2 * _3,
-            val_, a_->val(), b_->val());
-  }
-
-  void backward() {
-    Element(_1 += _2 * _3,
-            a_->grad(), adj_, b_->val());
-    Element(_1 += _2 * _3,
-            b_->grad(), adj_, a_->val());
-  }
-};
-
-inline Var operator*(Var a, Var b) {
-  return Var(new MultNodeOp(a, b));
-}
-
-/*** Division ***/
-
-struct DivNodeOp : public BroadcastingNodeOp {
-  DivNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
-
-  void forward() {
-    Element(_1 = _2 / _3,
-            val_, a_->val(), b_->val());
-  }
-
-  void backward() {
-    Element(_1 += _2 * 1.0f / _3,
-            a_->grad(), adj_, b_->val());
-    Element(_1 -= _2 * _3 / (_4 * _4),
-            b_->grad(), adj_, a_->val(), b_->val());
-  }
-};
-
-inline Var operator/(Var a, Var b) {
-  return Var(new DivNodeOp(a, b));
-}
-
-
-/*** Reductions ***/
-
-enum Axis { undef, axis0, axis1, axis2, axis3 };
-
-// inefficient
-inline Var sum(Var a, Axis axis = Axis::undef) {
-  if(axis == Axis::axis0) {
-    int rows = a.val().shape()[0];
-    int cols = a.val().shape()[1];
-    Var one = Tensor({1, rows}, 1);
-    return dot(one, a);
-  }
-  else if(axis == Axis::axis1) {
-    int rows = a.val().shape()[0];
-    int cols = a.val().shape()[1];
-    Var one = Tensor({cols, 1}, 1);
-    return dot(a, one);
-  }
-  else if(axis == Axis::axis2) {
-    UTIL_THROW2("Not implemented");
-  }
-  else if(axis == Axis::axis3) {
-    UTIL_THROW2("Not implemented");
-  }
-  return sum(sum(a, Axis::axis0), Axis::axis1);
-}
-
-// inefficient
-inline Var softmax(Var a, Axis axis = Axis::undef) {
-  Var e = exp(a);
-  return e / sum(e, axis);
-}
-
-// inefficient
-inline Var mean(Var a, Axis axis = Axis::undef) {
-  switch (axis) {
-    case Axis::axis0:
-      return sum(a, axis) / a.val().shape()[0];
-    case Axis::axis1:
-      return sum(a, axis) / a.val().shape()[1];
-    case Axis::axis2:
-      UTIL_THROW2("Not implemented");
-    case Axis::axis3:
-      UTIL_THROW2("Not implemented");
-    case Axis::undef:
-    default:
-      return sum(a) / a.val().size();
-  }
-}
-
-}
\ No newline at end of file
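The reductions removed here survive in expression_operators.h with the Axis enum replaced by the axis keyword; roughly, call sites change as follows (illustrative only):

    // old (operators.h, removed)          // new (expression_operators.h)
    Var  s = sum(a, Axis::axis0);          Expr s = sum(a, keywords::axis=0);
    Var  p = softmax(a, Axis::axis1);      Expr p = softmax(a, keywords::axis=1);
    Var  m = mean(a);                      Expr m = mean(a);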
diff --git a/src/tensor.h b/src/tensor.h
new file mode 100644
index 00000000..99646ceb
--- /dev/null
+++ b/src/tensor.h
@@ -0,0 +1,395 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#include "definitions.h"
+#include "exception.h"
+#include "thrust_functions.h"
+
+namespace marian {
+
+struct Handles {
+  cudnnHandle_t cudnnHandle;
+  cublasHandle_t cublasHandle;
+
+  cudnnOpTensorDescriptor_t add;
+
+  Handles() {
+    cudnnCreate(&cudnnHandle);
+    cublasCreate(&cublasHandle);
+    cudnnCreateOpTensorDescriptor(&add);
+    cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN);
+  }
+
+  ~Handles() {
+    cudnnDestroy(cudnnHandle);
+    cublasDestroy(cublasHandle);
+    cudnnDestroyOpTensorDescriptor(add);
+  }
+};
+
+Handles handles;
+
+typedef std::vector<int> Shape;
+
+template <class Float>
+class TensorImpl {
+  private:
+    Shape shape_;
+    thrust::device_vector<Float> data_;
+    cudnnTensorDescriptor_t desc_;
+    size_t tno_;
+    static size_t tensorCounter;
+
+    cudnnDataType_t dataType() {
+      switch(sizeof(Float)) {
+        case 2: return CUDNN_DATA_HALF;
+        case 8: return CUDNN_DATA_DOUBLE;
+        default: return CUDNN_DATA_FLOAT;
+      }
+    }
+
+  public:
+    typedef Float value_type;
+
+    TensorImpl(const Shape& shape, value_type value = 0)
+    : shape_(shape), tno_(tensorCounter++)
+    {
+      // @TODO:
+      UTIL_THROW_IF2(shape_.size() != 2,
+                     "For now, only 2D Tensors, will be fixed later.");
+
+      UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4,
+                     "Wrong number of dimensions: " << shape_.size());
+      int size = std::accumulate(shape_.begin(), shape_.end(),
+                                 1, std::multiplies<int>());
+      data_.resize(size, value);
+      cudnnCreateTensorDescriptor(&desc_);
+      switch (shape_.size()) {
+        case 1:
+          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
+                                     shape_[0], 1, 1, 1); break;
+        case 2:
+          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
+                                     shape_[0], shape_[1], 1, 1); break;
+        case 3:
+          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
+                                     shape_[0], shape_[1], shape_[2], 1); break;
+        case 4:
+          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
+                                     shape_[0], shape_[1], shape_[2], shape_[3]); break;
+      }
+    }
+
+    TensorImpl(const TensorImpl&) = delete;
+    TensorImpl(TensorImpl&&) = delete;
+
+    ~TensorImpl() {
+      cudnnDestroyTensorDescriptor(desc_);
+    }
+
+    value_type operator[](size_t i) const {
+      return data_[i];
+    }
+
+    auto begin() -> decltype( data_.begin() ) {
+      return data_.begin();
+    }
+
+    auto begin() const -> decltype( data_.begin() ) {
+      return data_.begin();
+    }
+
+    auto end() -> decltype( data_.end() ) {
+      return data_.end();
+    }
+
+    auto end() const -> decltype( data_.end() ) {
+      return data_.end();
+    }
+
+    const Shape& shape() const {
+      return shape_;
+    }
+
+    size_t size() const {
+      return data_.size();
+    }
+
+    value_type* data() {
+      return thrust::raw_pointer_cast(data_.data());
+    }
+
+    cudnnTensorDescriptor_t desc() const {
+      return desc_;
+    }
+
+    size_t id() const {
+      return tno_;
+    }
+
+    void set(value_type value) {
+      thrust::fill(data_.begin(), data_.end(), value);
+    }
+};
+
+template <class Float>
+size_t TensorImpl<Float>::tensorCounter = 0;
+
+class Tensor {
+  private:
+    std::shared_ptr<TensorImpl<Float>> pimpl_;
+
+  public:
+    typedef TensorImpl<Float>::value_type value_type;
+
+    Tensor() {}
+    ~Tensor() {}
+
+    void allocate(Shape shape, value_type value = 0) {
+      pimpl_.reset(new TensorImpl<Float>(shape, value));
+    }
+
+    value_type operator[](size_t i) const {
+      return (*pimpl_)[i];
+    }
+
+    size_t size() const {
+      return pimpl_->size();
+    }
+
+    value_type* data() {
+      return pimpl_->data();
+    }
+
+    const value_type* data() const {
+      return pimpl_->data();
+    }
+
+    auto begin() -> decltype( pimpl_->begin() ) {
+      return pimpl_->begin();
+    }
+
+    auto begin() const -> decltype( pimpl_->begin() ) {
+      return pimpl_->begin();
+    }
+
+    auto end() -> decltype( pimpl_->begin() ) {
+      return pimpl_->begin();
+    }
+
+    auto end() const -> decltype( pimpl_->begin() ) {
+      return pimpl_->begin();
+    }
+
+    const Shape& shape() const {
+      return pimpl_->shape();
+    }
+
+    cudnnTensorDescriptor_t desc() const {
+      return pimpl_->desc();
+    }
+
+    void set(value_type value) {
+      pimpl_->set(value);
+    }
+
+    size_t id() const {
+      return pimpl_->id();
+    }
+
+    operator bool() {
+      return pimpl_ != nullptr;
+    }
+};
+
+Tensor uniform(Tensor t, Float a=-0.1, Float b=0.1) {
+  std::vector<Float> r(t.size());
+  for(int i = 0; i < r.size(); i++)
+    r[i] = (Float(rand() % 2000) - 1000.0)/10000.0;
+  thrust::copy(r.begin(), r.end(), t.begin());
+  return t;
+};
+
+using namespace thrust::placeholders;
+#define MAX_THREADS 512
+#define MAX_BLOCKS 65535
+
+template <class Functor>
+__global__ void gElement(Functor functor, Float* out,
+                         size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if(j < rows) {
+      Float* rowOut = out + j * cols;
+      for(int tid = 0; tid < cols; tid += blockDim.x) {
+        int i = tid + threadIdx.x;
+        if(i < cols)
+          rowOut[i] = functor(rowOut[i]);;
+      }
+    }
+  }
+}
+
+template <class Functor>
+__global__ void gElement(Functor functor,
+                         Float* out, const Float* in,
+                         size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if(j < rows) {
+      Float* rowOut = out + j * cols;
+      const Float* rowIn = in + j * cols;
+
+      for(int tid = 0; tid < cols; tid += blockDim.x) {
+        int i = tid + threadIdx.x;
+        if(i < cols)
+          rowOut[i] = functor(rowOut[i], rowIn[i]);;
+      }
+    }
+  }
+}
+
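These kernels are grid-stride loops over a row-major matrix: the Element() wrappers further down launch one block per row and one thread per column (capped at MAX_BLOCKS and MAX_THREADS) and apply a thrust placeholder expression elementwise. A typical call, mirroring how the node classes use it (hypothetical shapes, illustrative only):

    Tensor A, B, C;
    C.allocate({2, 3});
    A.allocate({2, 3}, 1.0);
    B.allocate({2, 3}, 2.0);

    // C[i][j] = A[i][j] + 3 * B[i][j], evaluated on the GPU by the three-tensor overload
    Element(_1 = _2 + 3.0f * _3, C, A, B);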
+template <class Functor>
+__global__ void gElement(Functor functor,
+                         Float* out, const Float* in1, const Float* in2,
+                         size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if(j < rows) {
+      Float* rowOut = out + j * cols;
+      const Float* rowIn1 = in1 + j * cols;
+      const Float* rowIn2 = in2 + j * cols;
+
+      for(int tid = 0; tid < cols; tid += blockDim.x) {
+        int i = tid + threadIdx.x;
+        if(i < cols)
+          rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i]);
+      }
+    }
+  }
+}
+
+template <class Functor>
+__global__ void gElement(Functor functor,
+                         Float* out, const Float* in1,
+                         const Float* in2, const Float* in3,
+                         size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if(j < rows) {
+      Float* rowOut = out + j * cols;
+      const Float* rowIn1 = in1 + j * cols;
+      const Float* rowIn2 = in2 + j * cols;
+      const Float* rowIn3 = in3 + j * cols;
+
+      for(int tid = 0; tid < cols; tid += blockDim.x) {
+        int i = tid + threadIdx.x;
+        if(i < cols)
+          rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i], rowIn3[i]);
+      }
+    }
+  }
+}
+
+// @TODO add broadcasting
+
+template <class Functor>
+void Element(Functor functor, Tensor Out) {
+  Float* d_out = Out.data();
+  int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
+  int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
+  gElement<<<blocks, threads>>>(functor, d_out,
+                                Out.shape()[0], Out.shape()[1]);
+  cudaStreamSynchronize(0);
+}
+
+template <class Functor>
+void Element(Functor functor,
+             Tensor Out, const Tensor In) {
+  Float* d_out = Out.data();
+  const Float* d_in = In.data();
+
+  int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
+  int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
+  gElement<<<blocks, threads>>>(functor, d_out, d_in,
+                                Out.shape()[0], Out.shape()[1]);
+  cudaStreamSynchronize(0);
+}
+
+template <class Functor>
+void Element(Functor functor,
+             Tensor Out, const Tensor In1, const Tensor In2) {
+
+  Float* d_out = Out.data();
+  const Float* d_in1 = In1.data();
+  const Float* d_in2 = In2.data();
+
+  int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
+  int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
+  gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2,
+                                Out.shape()[0], Out.shape()[1]);
+  cudaStreamSynchronize(0);
+}
+
+template <class Functor>
+void Element(Functor functor,
+             Tensor Out, const Tensor In1,
+             const Tensor In2, const Tensor In3) {
+
+  Float* d_out = Out.data();
+  const Float* d_in1 = In1.data();
+  const Float* d_in2 = In2.data();
+  const Float* d_in3 = In3.data();
+
+  int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
+  int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
+  gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2, d_in3,
+                                Out.shape()[0], Out.shape()[1]);
+  cudaStreamSynchronize(0);
+}
+
+Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
+            bool transA, bool transB, Float beta) {
+  Float alpha = 1.0;
+
+  size_t m = A.shape()[0];
+  size_t k = A.shape()[1];
+  if(transA)
+    std::swap(m, k);
+
+  size_t l = B.shape()[0];
+  size_t n = B.shape()[1];
+  if(transB)
+    std::swap(l, n);
+
+  size_t lda = A.shape()[1];
+  size_t ldb = B.shape()[1];
+  size_t ldc = B.shape()[1];
+
+  if(transB)
+    ldc = B.shape()[0];
+
+  cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+  cublasSgemm(handle, opB, opA,
+              n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc);
+  return C;
+}
+
+Tensor Prod(Tensor C, const Tensor A, const Tensor B,
+            bool transA, bool transB, Float beta = 0) {
+
+  return Prod(handles.cublasHandle, C, A, B, transA, transB, beta);
+}
+
+}
\ No newline at end of file
diff --git a/src/tensor_operators.h b/src/tensor_operators.h
new file mode 100644
index 00000000..e69de29b
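One note on Prod(): cuBLAS expects column-major matrices while these tensors are stored row-major. Passing the operands in swapped order (B first, A second) makes cublasSgemm compute B^T * A^T in column-major terms, which equals (A * B)^T, and that transpose laid out column-major is exactly the row-major layout of A * B; the leading dimensions lda, ldb and ldc are therefore the row widths of A, B and C. A small sanity check one could run against this wrapper (hypothetical values):

    Tensor A, B, C;
    A.allocate({2, 3}, 1.0);   // 2 x 3, all ones
    B.allocate({3, 4}, 1.0);   // 3 x 4, all ones
    C.allocate({2, 4});

    Prod(C, A, B, false, false);   // C = A * B; every entry should equal 3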