mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
towards lazy allocation
This commit is contained in:
parent
4b79f3e72a
commit
8eb641ad45
29
src/definitions.h
Normal file
29
src/definitions.h
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
typedef float Float;
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "keywords.h"
|
||||||
|
#include "tensor.h"
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
typedef std::vector<int> Shape;
|
||||||
|
const int whatevs{-1};
|
||||||
|
|
||||||
|
namespace keywords {
|
||||||
|
KEY(init, std::function<void(Tensor)>)
|
||||||
|
KEY(axis, int)
|
||||||
|
KEY(name, std::string)
|
||||||
|
KEY(shape, Shape)
|
||||||
|
KEY(value, float)
|
||||||
|
KEY(lazy_shape, std::function<Shape()>)
|
||||||
|
KEY(lazy_value, std::function<float()>)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
185
src/expression_operators.h
Normal file
185
src/expression_operators.h
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "graph.h"
|
||||||
|
#include "graph_operators.h"
|
||||||
|
#include "expressions.h"
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
template <typename ...Args>
|
||||||
|
inline Expr data(Args ...args) {
|
||||||
|
return Expr(new DataNode(args...));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ...Args>
|
||||||
|
inline Expr param(Args ...args) {
|
||||||
|
return Expr(new ParamNode(args...));
|
||||||
|
}
|
||||||
|
template <typename ...Args>
|
||||||
|
inline Expr constant(Args ...args) {
|
||||||
|
return Expr(new ConstantNode(args...));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ...Args>
|
||||||
|
inline Expr ones(Args ...args) {
|
||||||
|
return Expr(new ConstantNode(keywords::value=1, args...));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ...Args>
|
||||||
|
inline Expr zeroes(Args ...args) {
|
||||||
|
return Expr(new ConstantNode(keywords::value=0, args...));
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************/
|
||||||
|
|
||||||
|
inline Expr sigmoid(Expr a) {
|
||||||
|
return Expr(new SigmoidNodeOp(a));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Expr tanh(Expr a) {
|
||||||
|
return Expr(new TanhNodeOp(a));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Expr log(Expr a) {
|
||||||
|
return Expr(new LogNodeOp(a));
|
||||||
|
};
|
||||||
|
|
||||||
|
inline Expr exp(Expr a) {
|
||||||
|
return Expr(new ExpNodeOp(a));
|
||||||
|
};
|
||||||
|
|
||||||
|
inline Expr operator-(Expr a) {
|
||||||
|
return Expr(new NegNodeOp(a));
|
||||||
|
};
|
||||||
|
|
||||||
|
/*********************************************************/
|
||||||
|
|
||||||
|
inline Expr operator+(Expr a, Expr b) {
|
||||||
|
return Expr(new PlusNodeOp(a, b));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Expr operator-(Expr a, Expr b) {
|
||||||
|
return Expr(new MinusNodeOp(a, b));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Expr operator*(Expr a, Expr b) {
|
||||||
|
return Expr(new MultNodeOp(a, b));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Expr operator/(Expr a, Expr b) {
|
||||||
|
return Expr(new DivNodeOp(a, b));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Expr dot(Expr a, Expr b) {
|
||||||
|
return Expr(new DotNodeOp(a, b));
|
||||||
|
}
|
||||||
|
|
||||||
|
/******************************************************/
|
||||||
|
|
||||||
|
Expr broadcast(Shape shape, Expr a) {
|
||||||
|
if(a.val().shape() == shape) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
size_t dimsA = a.val().shape().size();
|
||||||
|
size_t dimsB = shape.size();
|
||||||
|
UTIL_THROW_IF2(dimsA != dimsB,
|
||||||
|
"Tensor and shape have different number of dimensions");
|
||||||
|
for(size_t i = 0; i < dimsA; ++i) {
|
||||||
|
int dimA = a.val().shape()[i];
|
||||||
|
int dimB = shape[i];
|
||||||
|
bool broadcastable = (dimA == dimB || dimA == 1);
|
||||||
|
UTIL_THROW_IF2(!broadcastable,
|
||||||
|
"Cannot broadcast tensor dimension "
|
||||||
|
<< dimA << " to " << dimB);
|
||||||
|
if(dimA == 1 && dimB > 1) {
|
||||||
|
std::cerr << "Broadcasting dim " << i << " from " << dimA << " to " << dimB << std::endl;
|
||||||
|
if(i == 0) {
|
||||||
|
Expr one = ones(keywords::shape={shape[0], 1});
|
||||||
|
a = dot(one, a);
|
||||||
|
}
|
||||||
|
else if(i == 1) {
|
||||||
|
Expr one = ones(keywords::shape={1, shape[1]});
|
||||||
|
a = dot(a, one);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
UTIL_THROW2("Not implemented");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************/
|
||||||
|
|
||||||
|
// inefficient
|
||||||
|
template <typename ...Args>
|
||||||
|
inline Expr sum(Expr a, Args ...args) {
|
||||||
|
using namespace keywords;
|
||||||
|
Keywords params(args...);
|
||||||
|
int ax = params.Get<int>(axis, whatevs);
|
||||||
|
|
||||||
|
if(ax == 0) {
|
||||||
|
auto lshape = [&a]() -> Shape {
|
||||||
|
int rows = a.val().shape()[0];
|
||||||
|
return {1, rows};
|
||||||
|
};
|
||||||
|
Expr one = ones(lazy_shape=lshape);
|
||||||
|
return dot(one, a);
|
||||||
|
}
|
||||||
|
else if(ax == 1) {
|
||||||
|
auto lshape = [&a]() -> Shape {
|
||||||
|
int cols = a.val().shape()[1];
|
||||||
|
return {cols, 1};
|
||||||
|
};
|
||||||
|
Expr one = ones(lazy_shape=lshape);
|
||||||
|
return dot(a, one);
|
||||||
|
}
|
||||||
|
else if(ax == 2) {
|
||||||
|
UTIL_THROW2("Not implemented");
|
||||||
|
}
|
||||||
|
else if(ax == 3) {
|
||||||
|
UTIL_THROW2("Not implemented");
|
||||||
|
}
|
||||||
|
return sum(sum(a, axis=0), axis=1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// inefficient
|
||||||
|
template <typename ...Args>
|
||||||
|
inline Expr softmax(Expr a, Args ...args) {
|
||||||
|
Expr e = exp(a);
|
||||||
|
return e / sum(e, args...);
|
||||||
|
}
|
||||||
|
|
||||||
|
// inefficient
|
||||||
|
template <typename ...Args>
|
||||||
|
inline Expr mean(Expr a, Args ...args) {
|
||||||
|
using namespace keywords;
|
||||||
|
Keywords params(args...);
|
||||||
|
size_t ax = params.Get<int>(axis, whatevs);
|
||||||
|
|
||||||
|
switch (ax) {
|
||||||
|
case 0:
|
||||||
|
return sum(a, axis=0) / constant(shape={1, 1},
|
||||||
|
lazy_value=[&a]() -> Float {
|
||||||
|
return a.val().shape()[0];
|
||||||
|
});
|
||||||
|
case 1:
|
||||||
|
return sum(a, axis=1) / constant(shape={1, 1},
|
||||||
|
lazy_value=[&a]() -> Float {
|
||||||
|
return a.val().shape()[1];
|
||||||
|
});
|
||||||
|
case 2:
|
||||||
|
UTIL_THROW2("Not implemented");
|
||||||
|
case 3:
|
||||||
|
UTIL_THROW2("Not implemented");
|
||||||
|
default:
|
||||||
|
return sum(a) / constant(shape={1, 1},
|
||||||
|
lazy_value=[&a]() -> Float {
|
||||||
|
return a.val().size();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
55
src/expressions.h
Normal file
55
src/expressions.h
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
class Expr {
|
||||||
|
public:
|
||||||
|
Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
|
||||||
|
|
||||||
|
Tensor val() {
|
||||||
|
return pimpl_->val();
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor grad() {
|
||||||
|
return pimpl_->grad();
|
||||||
|
}
|
||||||
|
|
||||||
|
ChainPtr pimpl() {
|
||||||
|
return pimpl_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
UTIL_THROW_IF2(pimpl_.get() != stack.back(),
|
||||||
|
"Trying to call forward on non-root of computation graph");
|
||||||
|
|
||||||
|
std::cerr << "a" << std::endl;
|
||||||
|
for(auto&& v : stack)
|
||||||
|
v->allocate();
|
||||||
|
|
||||||
|
std::cerr << "f" << std::endl;
|
||||||
|
for(auto&& v : stack)
|
||||||
|
v->forward();
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
UTIL_THROW_IF2(pimpl_.get() != stack.back(),
|
||||||
|
"Trying to call backward on non-root of computation graph");
|
||||||
|
|
||||||
|
for(auto&& v : stack)
|
||||||
|
v->set_zero_adjoint();
|
||||||
|
|
||||||
|
typedef ChainableStack::reverse_iterator It;
|
||||||
|
pimpl_->init_dependent();
|
||||||
|
for(It it = stack.rbegin(); it != stack.rend(); ++it)
|
||||||
|
(*it)->backward();
|
||||||
|
}
|
||||||
|
|
||||||
|
operator ChainPtr() {
|
||||||
|
return pimpl_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
ChainPtr pimpl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
83
src/graph.h
Normal file
83
src/graph.h
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "keywords.h"
|
||||||
|
#include "tensor.h"
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
template <class DataType>
|
||||||
|
struct Chainable {
|
||||||
|
Chainable() { }
|
||||||
|
virtual ~Chainable() { }
|
||||||
|
virtual void forward() { }
|
||||||
|
virtual void backward() { }
|
||||||
|
virtual void init_dependent() { }
|
||||||
|
virtual void set_zero_adjoint() { }
|
||||||
|
|
||||||
|
virtual void allocate() = 0;
|
||||||
|
|
||||||
|
virtual DataType val() = 0;
|
||||||
|
virtual DataType grad() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::vector<Chainable<Tensor>*> ChainableStack;
|
||||||
|
typedef std::shared_ptr<Chainable<Tensor>> ChainPtr;
|
||||||
|
|
||||||
|
ChainableStack stack;
|
||||||
|
|
||||||
|
class Node : public Chainable<Tensor>,
|
||||||
|
public keywords::Keywords {
|
||||||
|
public:
|
||||||
|
template <typename ...Args>
|
||||||
|
Node(Args ...args)
|
||||||
|
: Keywords(args...),
|
||||||
|
shape_(Get<Shape>(keywords::shape, {1, 1})),
|
||||||
|
name_(Get<std::string>(keywords::name, "none"))
|
||||||
|
{
|
||||||
|
std::cerr << "Creating node " << name_ << std::endl;
|
||||||
|
stack.push_back(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~Node() {};
|
||||||
|
|
||||||
|
virtual void allocate() {
|
||||||
|
val_.allocate(shape_);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void init_dependent() {
|
||||||
|
if(adj_) {
|
||||||
|
adj_.set(1);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
adj_.allocate(shape_, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void set_zero_adjoint() {
|
||||||
|
if(adj_) {
|
||||||
|
adj_.set(0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
adj_.allocate(shape_, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual Tensor val() {
|
||||||
|
UTIL_THROW_IF2(!val_, "Tensor has not been allocated");
|
||||||
|
return val_;
|
||||||
|
};
|
||||||
|
|
||||||
|
virtual Tensor grad() {
|
||||||
|
UTIL_THROW_IF2(!adj_, "Tensor has not been allocated");
|
||||||
|
return adj_;
|
||||||
|
};
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Shape shape_;
|
||||||
|
std::string name_;
|
||||||
|
|
||||||
|
Tensor val_;
|
||||||
|
Tensor adj_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
271
src/graph_operators.h
Normal file
271
src/graph_operators.h
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "graph.h"
|
||||||
|
#include "expressions.h"
|
||||||
|
//#include "expression_operators.h"
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
struct DataNode : public Node {
|
||||||
|
template <typename ...Args>
|
||||||
|
DataNode(Args ...args)
|
||||||
|
: Node(args...) { }
|
||||||
|
|
||||||
|
void forward() {}
|
||||||
|
void backward() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ConstantNode : public Node {
|
||||||
|
template <typename ...Args>
|
||||||
|
ConstantNode(Args ...args)
|
||||||
|
: Node(args...) { }
|
||||||
|
|
||||||
|
void forward() {}
|
||||||
|
void backward() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ParamNode : public Node {
|
||||||
|
template <typename ...Args>
|
||||||
|
ParamNode(Args ...args)
|
||||||
|
: Node(args...),
|
||||||
|
init_(Get<std::function<void(Tensor)>>(keywords::init, [](Tensor){ }))
|
||||||
|
{ }
|
||||||
|
|
||||||
|
void forward() {}
|
||||||
|
void backward() {}
|
||||||
|
|
||||||
|
virtual void allocate() {
|
||||||
|
val_.allocate(shape_);
|
||||||
|
init_(val_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::function<void(Tensor)> init_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct UnaryNodeOp : public Node {
|
||||||
|
ChainPtr a_;
|
||||||
|
|
||||||
|
template <typename ...Args>
|
||||||
|
UnaryNodeOp(ChainPtr a, Args ...args)
|
||||||
|
: Node(args...), a_(a) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SigmoidNodeOp : public UnaryNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
SigmoidNodeOp(Args ...args)
|
||||||
|
: UnaryNodeOp(args...) { }
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
Element(_1 = Sigma(_2),
|
||||||
|
val_, a_->val());
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
Element(_1 += _2 * Sigma(_3) * (1 - Sigma(_3)),
|
||||||
|
a_->grad(), adj_, a_->val());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TanhNodeOp : public UnaryNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
TanhNodeOp(Args ...args)
|
||||||
|
: UnaryNodeOp(args...) { }
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
Element(_1 = Tanh(_2),
|
||||||
|
val_, a_->val());
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
Element(_1 += _2 * (1 - Tanh(_3) * Tanh(_3)),
|
||||||
|
a_->grad(), adj_, a_->val());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LogNodeOp : public UnaryNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
LogNodeOp(Args ...args)
|
||||||
|
: UnaryNodeOp(args...) {
|
||||||
|
std::cerr << "log" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
Element(_1 = Log(_2), val_, a_->val());
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
Element(_1 += _2 * 1.f / _3,
|
||||||
|
a_->grad(), adj_, a_->val());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ExpNodeOp : public UnaryNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
ExpNodeOp(Args ...args)
|
||||||
|
: UnaryNodeOp(args...) { }
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
Element(_1 = Exp(_2), val_, a_->val());
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
Element(_1 += _2 * Exp(_3),
|
||||||
|
a_->grad(), adj_, a_->val());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NegNodeOp : public UnaryNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
NegNodeOp(Args ...args)
|
||||||
|
: UnaryNodeOp(args...) { }
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
Element(_1 = -_2, val_, a_->val());
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
Element(_1 += -_2, a_->grad(), adj_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/******************************************************/
|
||||||
|
|
||||||
|
struct BinaryNodeOp : public Node {
|
||||||
|
ChainPtr a_;
|
||||||
|
ChainPtr b_;
|
||||||
|
|
||||||
|
template <typename ...Args>
|
||||||
|
BinaryNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||||
|
: Node(args...), a_(a), b_(b) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*** Matrix Product ***/
|
||||||
|
|
||||||
|
struct DotNodeOp : public BinaryNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
DotNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||||
|
: BinaryNodeOp(a, b, args...) { }
|
||||||
|
|
||||||
|
Shape shape(ChainPtr a, ChainPtr b) {
|
||||||
|
UTIL_THROW_IF2(a->val().shape()[1] != b->val().shape()[0],
|
||||||
|
"matrix product requires dimensions to match");
|
||||||
|
Shape shape1 = a->val().shape();
|
||||||
|
Shape shape2 = b->val().shape();
|
||||||
|
shape1[1] = shape2[1];
|
||||||
|
return shape1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
// C = A*B
|
||||||
|
Prod(val_, a_->val(), b_->val(), false, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
// D is the adjoint, the matrix of derivatives
|
||||||
|
// df/dA += D*B.T
|
||||||
|
// df/dB += A.T*D
|
||||||
|
// beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
|
||||||
|
// to sum gradients from different graph parts
|
||||||
|
Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
|
||||||
|
Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Expr broadcast(Shape shape, Expr a);
|
||||||
|
|
||||||
|
struct BroadcastingNodeOp : public BinaryNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
BroadcastingNodeOp(Expr a, Expr b, Args ...args)
|
||||||
|
: BinaryNodeOp(broadcast(shape(a ,b), a),
|
||||||
|
broadcast(shape(a ,b), b),
|
||||||
|
args...) {}
|
||||||
|
|
||||||
|
static Shape shape(ChainPtr a, ChainPtr b) {
|
||||||
|
size_t dimsA = a->val().shape().size();
|
||||||
|
size_t dimsB = b->val().shape().size();
|
||||||
|
UTIL_THROW_IF2(dimsA != dimsB,
|
||||||
|
"Tensors have different numbers of dimensions");
|
||||||
|
Shape shape(dimsA);
|
||||||
|
for(size_t i = 0; i < dimsA; ++i) {
|
||||||
|
int dimA = a->val().shape()[i];
|
||||||
|
int dimB = b->val().shape()[i];
|
||||||
|
bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1);
|
||||||
|
UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
|
||||||
|
<< "operation cannot be broadcasted: " << dimA << " != " << dimB);
|
||||||
|
shape[i] = std::max(dimA, dimB);
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct PlusNodeOp : public BroadcastingNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
PlusNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
Element(_1 = _2 + _3,
|
||||||
|
val_, a_->val(), b_->val());
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
Element(_1 += _2,
|
||||||
|
a_->grad(), adj_);
|
||||||
|
Element(_1 += _2,
|
||||||
|
b_->grad(), adj_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MinusNodeOp : public BroadcastingNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
MinusNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
Element(_1 = _2 - _3,
|
||||||
|
val_, a_->val(), b_->val());
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
Element(_1 += _2,
|
||||||
|
a_->grad(), adj_);
|
||||||
|
Element(_1 -= _2,
|
||||||
|
b_->grad(), adj_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MultNodeOp : public BroadcastingNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
MultNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
Element(_1 = _2 * _3,
|
||||||
|
val_, a_->val(), b_->val());
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
Element(_1 += _2 * _3,
|
||||||
|
a_->grad(), adj_, b_->val());
|
||||||
|
Element(_1 += _2 * _3,
|
||||||
|
b_->grad(), adj_, a_->val());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DivNodeOp : public BroadcastingNodeOp {
|
||||||
|
template <typename ...Args>
|
||||||
|
DivNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
|
||||||
|
|
||||||
|
void forward() {
|
||||||
|
Element(_1 = _2 / _3,
|
||||||
|
val_, a_->val(), b_->val());
|
||||||
|
}
|
||||||
|
|
||||||
|
void backward() {
|
||||||
|
Element(_1 += _2 * 1.0f / _3,
|
||||||
|
a_->grad(), adj_, b_->val());
|
||||||
|
Element(_1 -= _2 * _3 / (_4 * _4),
|
||||||
|
b_->grad(), adj_, a_->val(), b_->val());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
361
src/operators.h
361
src/operators.h
@ -1,361 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <functional>
|
|
||||||
#include <vector>
|
|
||||||
#include <cmath>
|
|
||||||
|
|
||||||
#include "marian.h"
|
|
||||||
#include "cudnn_tensor.h"
|
|
||||||
|
|
||||||
namespace marian {
|
|
||||||
|
|
||||||
/*** Unary operators ***/
|
|
||||||
|
|
||||||
struct UnaryNodeOp : public Node {
|
|
||||||
ChainPtr a_;
|
|
||||||
|
|
||||||
UnaryNodeOp(const Tensor t, ChainPtr a)
|
|
||||||
: Node(t), a_(a) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct SigmaNodeOp : public UnaryNodeOp {
|
|
||||||
SigmaNodeOp(ChainPtr a)
|
|
||||||
: UnaryNodeOp(Tensor(a->val().shape()), a) { }
|
|
||||||
|
|
||||||
void forward() {
|
|
||||||
Element(_1 = Sigma(_2),
|
|
||||||
val_, a_->val());
|
|
||||||
}
|
|
||||||
|
|
||||||
void backward() {
|
|
||||||
Element(_1 += _2 * Sigma(_3) * (1 - Sigma(_3)),
|
|
||||||
a_->grad(), adj_, a_->val());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
inline Var sigma(Var a) {
|
|
||||||
return Var(new SigmaNodeOp(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
struct TanhNodeOp : public UnaryNodeOp {
|
|
||||||
TanhNodeOp(ChainPtr a)
|
|
||||||
: UnaryNodeOp(Tensor(a->val().shape()), a) { }
|
|
||||||
|
|
||||||
void forward() {
|
|
||||||
Element(_1 = Tanh(_2),
|
|
||||||
val_, a_->val());
|
|
||||||
}
|
|
||||||
|
|
||||||
void backward() {
|
|
||||||
Element(_1 += _2 * (1 - Tanh(_3) * Tanh(_3)),
|
|
||||||
a_->grad(), adj_, a_->val());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
inline Var tanh(Var a) {
|
|
||||||
return Var(new TanhNodeOp(a));
|
|
||||||
}
|
|
||||||
|
|
||||||
struct LogNodeOp : public UnaryNodeOp {
|
|
||||||
LogNodeOp(ChainPtr a)
|
|
||||||
: UnaryNodeOp(Tensor(a->val().shape()), a) { }
|
|
||||||
|
|
||||||
void forward() {
|
|
||||||
Element(_1 = Log(_2), val_, a_->val());
|
|
||||||
}
|
|
||||||
|
|
||||||
void backward() {
|
|
||||||
Element(_1 += _2 * 1.f / _3,
|
|
||||||
a_->grad(), adj_, a_->val());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
inline Var log(Var a) {
|
|
||||||
return Var(new LogNodeOp(a));
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ExpNodeOp : public UnaryNodeOp {
|
|
||||||
ExpNodeOp(ChainPtr a)
|
|
||||||
: UnaryNodeOp(Tensor(a->val().shape()), a) { }
|
|
||||||
|
|
||||||
void forward() {
|
|
||||||
Element(_1 = Exp(_2), val_, a_->val());
|
|
||||||
}
|
|
||||||
|
|
||||||
void backward() {
|
|
||||||
Element(_1 += _2 * Exp(_3),
|
|
||||||
a_->grad(), adj_, a_->val());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
inline Var exp(Var a) {
|
|
||||||
return Var(new ExpNodeOp(a));
|
|
||||||
};
|
|
||||||
|
|
||||||
struct NegNodeOp : public UnaryNodeOp {
|
|
||||||
NegNodeOp(ChainPtr a)
|
|
||||||
: UnaryNodeOp(Tensor(a->val().shape()), a) { }
|
|
||||||
|
|
||||||
void forward() {
|
|
||||||
Element(_1 = -_2, val_, a_->val());
|
|
||||||
}
|
|
||||||
|
|
||||||
void backward() {
|
|
||||||
Element(_1 += -_2, a_->grad(), adj_);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
inline Var operator-(Var a) {
|
|
||||||
return Var(new NegNodeOp(a));
|
|
||||||
};
|
|
||||||
|
|
||||||
/******************************************************/
|
|
||||||
|
|
||||||
struct BinaryNodeOp : public Node {
|
|
||||||
ChainPtr a_;
|
|
||||||
ChainPtr b_;
|
|
||||||
|
|
||||||
BinaryNodeOp(const Tensor t, ChainPtr a, ChainPtr b)
|
|
||||||
: Node(t), a_(a), b_(b) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*** Matrix Product ***/
|
|
||||||
|
|
||||||
struct DotNodeOp : public BinaryNodeOp {
|
|
||||||
DotNodeOp(ChainPtr a, ChainPtr b) : BinaryNodeOp(Tensor(shape(a, b)), a, b) { }
|
|
||||||
|
|
||||||
Shape shape(ChainPtr a, ChainPtr b) {
|
|
||||||
UTIL_THROW_IF2(a->val().shape()[1] != b->val().shape()[0],
|
|
||||||
"matrix product requires dimensions to match");
|
|
||||||
Shape shape1 = a->val().shape();
|
|
||||||
Shape shape2 = b->val().shape();
|
|
||||||
shape1[1] = shape2[1];
|
|
||||||
return shape1;
|
|
||||||
}
|
|
||||||
|
|
||||||
void forward() {
|
|
||||||
// C = A*B
|
|
||||||
Prod(val_, a_->val(), b_->val(), false, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void backward() {
|
|
||||||
// D is the adjoint, the matrix of derivatives
|
|
||||||
// df/dA += D*B.T
|
|
||||||
// df/dB += A.T*D
|
|
||||||
// beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
|
|
||||||
// to sum gradients from different graph parts
|
|
||||||
Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
|
|
||||||
Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
inline Var dot(Var a, Var b) {
|
|
||||||
return Var(new DotNodeOp(a, b));
|
|
||||||
}
|
|
||||||
|
|
||||||
/******************************************************/
|
|
||||||
|
|
||||||
Var broadcast(Shape shape, Var a) {
|
|
||||||
if(a.val().shape() == shape) {
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
size_t dimsA = a.val().shape().size();
|
|
||||||
size_t dimsB = shape.size();
|
|
||||||
UTIL_THROW_IF2(dimsA != dimsB,
|
|
||||||
"Tensor and shape have different number of dimensions");
|
|
||||||
for(size_t i = 0; i < dimsA; ++i) {
|
|
||||||
int dimA = a.val().shape()[i];
|
|
||||||
int dimB = shape[i];
|
|
||||||
bool broadcastable = (dimA == dimB || dimA == 1);
|
|
||||||
UTIL_THROW_IF2(!broadcastable,
|
|
||||||
"Cannot broadcast tensor dimension "
|
|
||||||
<< dimA << " to " << dimB);
|
|
||||||
if(dimA == 1 && dimB > 1) {
|
|
||||||
std::cerr << "Broadcasting dim " << i << " from " << dimA << " to " << dimB << std::endl;
|
|
||||||
if(i == 0) {
|
|
||||||
Var one = Tensor({shape[0], 1}, 1);
|
|
||||||
a = dot(one, a);
|
|
||||||
}
|
|
||||||
else if(i == 1) {
|
|
||||||
Var one = Tensor({1, shape[1]}, 1);
|
|
||||||
a = dot(a, one);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
UTIL_THROW2("Not implemented");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct BroadcastingNodeOp : public BinaryNodeOp {
|
|
||||||
BroadcastingNodeOp(Var a, Var b)
|
|
||||||
: BroadcastingNodeOp(Tensor(shape(a ,b)), broadcast(shape(a ,b), a), broadcast(shape(a ,b), b)) {}
|
|
||||||
|
|
||||||
static Shape shape(ChainPtr a, ChainPtr b) {
|
|
||||||
size_t dimsA = a->val().shape().size();
|
|
||||||
size_t dimsB = b->val().shape().size();
|
|
||||||
UTIL_THROW_IF2(dimsA != dimsB,
|
|
||||||
"Tensors have different numbers of dimensions");
|
|
||||||
Shape shape(dimsA);
|
|
||||||
for(size_t i = 0; i < dimsA; ++i) {
|
|
||||||
int dimA = a->val().shape()[i];
|
|
||||||
int dimB = b->val().shape()[i];
|
|
||||||
bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1);
|
|
||||||
UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
|
|
||||||
<< "operation cannot be broadcasted: " << dimA << " != " << dimB);
|
|
||||||
shape[i] = std::max(dimA, dimB);
|
|
||||||
}
|
|
||||||
return shape;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
BroadcastingNodeOp(const Tensor t, ChainPtr a, ChainPtr b)
|
|
||||||
: BinaryNodeOp(t, a, b) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*** Binary arithmetic ***/
|
|
||||||
|
|
||||||
/*** Plus ***/
|
|
||||||
|
|
||||||
struct PlusNodeOp : public BroadcastingNodeOp {
|
|
||||||
PlusNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
|
|
||||||
|
|
||||||
void forward() {
|
|
||||||
Element(_1 = _2 + _3,
|
|
||||||
val_, a_->val(), b_->val());
|
|
||||||
}
|
|
||||||
|
|
||||||
void backward() {
|
|
||||||
Element(_1 += _2,
|
|
||||||
a_->grad(), adj_);
|
|
||||||
Element(_1 += _2,
|
|
||||||
b_->grad(), adj_);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
inline Var operator+(Var a, Var b) {
|
|
||||||
return Var(new PlusNodeOp(a, b));
|
|
||||||
}
|
|
||||||
|
|
||||||
/*** Minus ***/
|
|
||||||
|
|
||||||
struct MinusNodeOp : public BroadcastingNodeOp {
|
|
||||||
MinusNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
|
|
||||||
|
|
||||||
void forward() {
|
|
||||||
Element(_1 = _2 - _3,
|
|
||||||
val_, a_->val(), b_->val());
|
|
||||||
}
|
|
||||||
|
|
||||||
void backward() {
|
|
||||||
Element(_1 += _2,
|
|
||||||
a_->grad(), adj_);
|
|
||||||
Element(_1 -= _2,
|
|
||||||
b_->grad(), adj_);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
inline Var operator-(Var a, Var b) {
|
|
||||||
return Var(new MinusNodeOp(a, b));
|
|
||||||
}
|
|
||||||
|
|
||||||
/*** Mult ***/
|
|
||||||
|
|
||||||
struct MultNodeOp : public BroadcastingNodeOp {
|
|
||||||
MultNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
|
|
||||||
|
|
||||||
void forward() {
|
|
||||||
Element(_1 = _2 * _3,
|
|
||||||
val_, a_->val(), b_->val());
|
|
||||||
}
|
|
||||||
|
|
||||||
void backward() {
|
|
||||||
Element(_1 += _2 * _3,
|
|
||||||
a_->grad(), adj_, b_->val());
|
|
||||||
Element(_1 += _2 * _3,
|
|
||||||
b_->grad(), adj_, a_->val());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
inline Var operator*(Var a, Var b) {
|
|
||||||
return Var(new MultNodeOp(a, b));
|
|
||||||
}
|
|
||||||
|
|
||||||
/*** Division ***/
|
|
||||||
|
|
||||||
struct DivNodeOp : public BroadcastingNodeOp {
|
|
||||||
DivNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
|
|
||||||
|
|
||||||
void forward() {
|
|
||||||
Element(_1 = _2 / _3,
|
|
||||||
val_, a_->val(), b_->val());
|
|
||||||
}
|
|
||||||
|
|
||||||
void backward() {
|
|
||||||
Element(_1 += _2 * 1.0f / _3,
|
|
||||||
a_->grad(), adj_, b_->val());
|
|
||||||
Element(_1 -= _2 * _3 / (_4 * _4),
|
|
||||||
b_->grad(), adj_, a_->val(), b_->val());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
inline Var operator/(Var a, Var b) {
|
|
||||||
return Var(new DivNodeOp(a, b));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/*** Reductions ***/
|
|
||||||
|
|
||||||
enum Axis { undef, axis0, axis1, axis2, axis3 };
|
|
||||||
|
|
||||||
// inefficient
|
|
||||||
inline Var sum(Var a, Axis axis = Axis::undef) {
|
|
||||||
if(axis == Axis::axis0) {
|
|
||||||
int rows = a.val().shape()[0];
|
|
||||||
int cols = a.val().shape()[1];
|
|
||||||
Var one = Tensor({1, rows}, 1);
|
|
||||||
return dot(one, a);
|
|
||||||
}
|
|
||||||
else if(axis == Axis::axis1) {
|
|
||||||
int rows = a.val().shape()[0];
|
|
||||||
int cols = a.val().shape()[1];
|
|
||||||
Var one = Tensor({cols, 1}, 1);
|
|
||||||
return dot(a, one);
|
|
||||||
}
|
|
||||||
else if(axis == Axis::axis2) {
|
|
||||||
UTIL_THROW2("Not implemented");
|
|
||||||
}
|
|
||||||
else if(axis == Axis::axis3) {
|
|
||||||
UTIL_THROW2("Not implemented");
|
|
||||||
}
|
|
||||||
return sum(sum(a, Axis::axis0), Axis::axis1);
|
|
||||||
}
|
|
||||||
|
|
||||||
// inefficient
|
|
||||||
inline Var softmax(Var a, Axis axis = Axis::undef) {
|
|
||||||
Var e = exp(a);
|
|
||||||
return e / sum(e, axis);
|
|
||||||
}
|
|
||||||
|
|
||||||
// inefficient
|
|
||||||
inline Var mean(Var a, Axis axis = Axis::undef) {
|
|
||||||
switch (axis) {
|
|
||||||
case Axis::axis0:
|
|
||||||
return sum(a, axis) / a.val().shape()[0];
|
|
||||||
case Axis::axis1:
|
|
||||||
return sum(a, axis) / a.val().shape()[1];
|
|
||||||
case Axis::axis2:
|
|
||||||
UTIL_THROW2("Not implemented");
|
|
||||||
case Axis::axis3:
|
|
||||||
UTIL_THROW2("Not implemented");
|
|
||||||
case Axis::undef:
|
|
||||||
default:
|
|
||||||
return sum(a) / a.val().size();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
395
src/tensor.h
Normal file
395
src/tensor.h
Normal file
@ -0,0 +1,395 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <functional>
|
||||||
|
#include <vector>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include <cudnn.h>
|
||||||
|
#include <cublas_v2.h>
|
||||||
|
#include <thrust/device_vector.h>
|
||||||
|
#include <thrust/functional.h>
|
||||||
|
|
||||||
|
#include "definitions.h"
|
||||||
|
#include "exception.h"
|
||||||
|
#include "thrust_functions.h"
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
struct Handles {
|
||||||
|
cudnnHandle_t cudnnHandle;
|
||||||
|
cublasHandle_t cublasHandle;
|
||||||
|
|
||||||
|
cudnnOpTensorDescriptor_t add;
|
||||||
|
|
||||||
|
Handles() {
|
||||||
|
cudnnCreate(&cudnnHandle);
|
||||||
|
cublasCreate(&cublasHandle);
|
||||||
|
cudnnCreateOpTensorDescriptor(&add);
|
||||||
|
cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN);
|
||||||
|
}
|
||||||
|
|
||||||
|
~Handles() {
|
||||||
|
cudnnDestroy(cudnnHandle);
|
||||||
|
cublasDestroy(cublasHandle);
|
||||||
|
cudnnDestroyOpTensorDescriptor(add);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Handles handles;
|
||||||
|
|
||||||
|
typedef std::vector<int> Shape;
|
||||||
|
|
||||||
|
template<class Float>
|
||||||
|
class TensorImpl {
|
||||||
|
private:
|
||||||
|
Shape shape_;
|
||||||
|
thrust::device_vector<Float> data_;
|
||||||
|
cudnnTensorDescriptor_t desc_;
|
||||||
|
size_t tno_;
|
||||||
|
static size_t tensorCounter;
|
||||||
|
|
||||||
|
cudnnDataType_t dataType() {
|
||||||
|
switch(sizeof(Float)) {
|
||||||
|
case 2: return CUDNN_DATA_HALF;
|
||||||
|
case 8: return CUDNN_DATA_DOUBLE;
|
||||||
|
default: return CUDNN_DATA_FLOAT;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef Float value_type;
|
||||||
|
|
||||||
|
TensorImpl(const Shape& shape, value_type value = 0)
|
||||||
|
: shape_(shape), tno_(tensorCounter++)
|
||||||
|
{
|
||||||
|
// @TODO:
|
||||||
|
UTIL_THROW_IF2(shape_.size() != 2,
|
||||||
|
"For now, only 2D Tensors, will be fixed later.");
|
||||||
|
|
||||||
|
UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4,
|
||||||
|
"Wrong number of dimensions: " << shape_.size());
|
||||||
|
int size = std::accumulate(shape_.begin(), shape_.end(),
|
||||||
|
1, std::multiplies<int>());
|
||||||
|
data_.resize(size, value);
|
||||||
|
cudnnCreateTensorDescriptor(&desc_);
|
||||||
|
switch (shape_.size()) {
|
||||||
|
case 1:
|
||||||
|
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
|
||||||
|
shape_[0], 1, 1, 1); break;
|
||||||
|
case 2:
|
||||||
|
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
|
||||||
|
shape_[0], shape_[1], 1, 1); break;
|
||||||
|
case 3:
|
||||||
|
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
|
||||||
|
shape_[0], shape_[1], shape_[2], 1); break;
|
||||||
|
case 4:
|
||||||
|
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
|
||||||
|
shape_[0], shape_[1], shape_[2], shape_[3]); break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorImpl(const TensorImpl&) = delete;
|
||||||
|
TensorImpl(TensorImpl&&) = delete;
|
||||||
|
|
||||||
|
~TensorImpl() {
|
||||||
|
cudnnDestroyTensorDescriptor(desc_);
|
||||||
|
}
|
||||||
|
|
||||||
|
value_type operator[](size_t i) const {
|
||||||
|
return data_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto begin() -> decltype( data_.begin() ) {
|
||||||
|
return data_.begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto begin() const -> decltype( data_.begin() ) {
|
||||||
|
return data_.begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto end() -> decltype( data_.end() ) {
|
||||||
|
return data_.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto end() const -> decltype( data_.end() ) {
|
||||||
|
return data_.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Shape& shape() const {
|
||||||
|
return shape_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size() const {
|
||||||
|
return data_.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
value_type* data() {
|
||||||
|
return thrust::raw_pointer_cast(data_.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
cudnnTensorDescriptor_t desc() const {
|
||||||
|
return desc_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t id() const {
|
||||||
|
return tno_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set(value_type value) {
|
||||||
|
thrust::fill(data_.begin(), data_.end(), value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Type>
|
||||||
|
size_t TensorImpl<Type>::tensorCounter = 0;
|
||||||
|
|
||||||
|
class Tensor {
|
||||||
|
private:
|
||||||
|
std::shared_ptr<TensorImpl<Float>> pimpl_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef TensorImpl<Float>::value_type value_type;
|
||||||
|
|
||||||
|
Tensor() {}
|
||||||
|
~Tensor() {}
|
||||||
|
|
||||||
|
void allocate(Shape shape, value_type value = 0) {
|
||||||
|
pimpl_.reset(new TensorImpl<Float>(shape, value));
|
||||||
|
}
|
||||||
|
|
||||||
|
value_type operator[](size_t i) const {
|
||||||
|
return (*pimpl_)[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size() const {
|
||||||
|
return pimpl_->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
value_type* data() {
|
||||||
|
return pimpl_->data();
|
||||||
|
}
|
||||||
|
|
||||||
|
const value_type* data() const {
|
||||||
|
return pimpl_->data();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto begin() -> decltype( pimpl_->begin() ) {
|
||||||
|
return pimpl_->begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto begin() const -> decltype( pimpl_->begin() ) {
|
||||||
|
return pimpl_->begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto end() -> decltype( pimpl_->begin() ) {
|
||||||
|
return pimpl_->begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto end() const -> decltype( pimpl_->begin() ) {
|
||||||
|
return pimpl_->begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Shape& shape() const {
|
||||||
|
return pimpl_->shape();
|
||||||
|
}
|
||||||
|
|
||||||
|
cudnnTensorDescriptor_t desc() const {
|
||||||
|
return pimpl_->desc();
|
||||||
|
}
|
||||||
|
|
||||||
|
void set(value_type value) {
|
||||||
|
pimpl_->set(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t id() const {
|
||||||
|
return pimpl_->id();
|
||||||
|
}
|
||||||
|
|
||||||
|
operator bool() {
|
||||||
|
return pimpl_ != nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Tensor uniform(Tensor t, Float a=-0.1, Float b=0.1) {
|
||||||
|
std::vector<Float> r(t.size());
|
||||||
|
for(int i = 0; i < r.size(); i++)
|
||||||
|
r[i] = (Float(rand() % 2000) - 1000.0)/10000.0;
|
||||||
|
thrust::copy(r.begin(), r.end(), t.begin());
|
||||||
|
return t;
|
||||||
|
};
|
||||||
|
|
||||||
|
using namespace thrust::placeholders;
|
||||||
|
#define MAX_THREADS 512
|
||||||
|
#define MAX_BLOCKS 65535
|
||||||
|
|
||||||
|
template <class Functor>
|
||||||
|
__global__ void gElement(Functor functor, Float* out,
|
||||||
|
size_t rows, size_t cols) {
|
||||||
|
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||||
|
int j = bid + blockIdx.x;
|
||||||
|
if(j < rows) {
|
||||||
|
Float* rowOut = out + j * cols;
|
||||||
|
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||||
|
int i = tid + threadIdx.x;
|
||||||
|
if(i < cols)
|
||||||
|
rowOut[i] = functor(rowOut[i]);;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Functor>
|
||||||
|
__global__ void gElement(Functor functor,
|
||||||
|
Float* out, const Float* in,
|
||||||
|
size_t rows, size_t cols) {
|
||||||
|
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||||
|
int j = bid + blockIdx.x;
|
||||||
|
if(j < rows) {
|
||||||
|
Float* rowOut = out + j * cols;
|
||||||
|
const Float* rowIn = in + j * cols;
|
||||||
|
|
||||||
|
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||||
|
int i = tid + threadIdx.x;
|
||||||
|
if(i < cols)
|
||||||
|
rowOut[i] = functor(rowOut[i], rowIn[i]);;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Functor>
|
||||||
|
__global__ void gElement(Functor functor,
|
||||||
|
Float* out, const Float* in1, const Float* in2,
|
||||||
|
size_t rows, size_t cols) {
|
||||||
|
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||||
|
int j = bid + blockIdx.x;
|
||||||
|
if(j < rows) {
|
||||||
|
Float* rowOut = out + j * cols;
|
||||||
|
const Float* rowIn1 = in1 + j * cols;
|
||||||
|
const Float* rowIn2 = in2 + j * cols;
|
||||||
|
|
||||||
|
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||||
|
int i = tid + threadIdx.x;
|
||||||
|
if(i < cols)
|
||||||
|
rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Functor>
|
||||||
|
__global__ void gElement(Functor functor,
|
||||||
|
Float* out, const Float* in1,
|
||||||
|
const Float* in2, const Float* in3,
|
||||||
|
size_t rows, size_t cols) {
|
||||||
|
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||||
|
int j = bid + blockIdx.x;
|
||||||
|
if(j < rows) {
|
||||||
|
Float* rowOut = out + j * cols;
|
||||||
|
const Float* rowIn1 = in1 + j * cols;
|
||||||
|
const Float* rowIn2 = in2 + j * cols;
|
||||||
|
const Float* rowIn3 = in3 + j * cols;
|
||||||
|
|
||||||
|
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||||
|
int i = tid + threadIdx.x;
|
||||||
|
if(i < cols)
|
||||||
|
rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i], rowIn3[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// @TODO add broadcasting
|
||||||
|
|
||||||
|
template <class Functor>
|
||||||
|
void Element(Functor functor, Tensor Out) {
|
||||||
|
Float* d_out = Out.data();
|
||||||
|
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
|
||||||
|
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
|
||||||
|
gElement<<<blocks, threads>>>(functor, d_out,
|
||||||
|
Out.shape()[0], Out.shape()[1]);
|
||||||
|
cudaStreamSynchronize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Functor>
|
||||||
|
void Element(Functor functor,
|
||||||
|
Tensor Out, const Tensor In) {
|
||||||
|
Float* d_out = Out.data();
|
||||||
|
const Float* d_in = In.data();
|
||||||
|
|
||||||
|
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
|
||||||
|
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
|
||||||
|
gElement<<<blocks, threads>>>(functor, d_out, d_in,
|
||||||
|
Out.shape()[0], Out.shape()[1]);
|
||||||
|
cudaStreamSynchronize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Functor>
|
||||||
|
void Element(Functor functor,
|
||||||
|
Tensor Out, const Tensor In1, const Tensor In2) {
|
||||||
|
|
||||||
|
Float* d_out = Out.data();
|
||||||
|
const Float* d_in1 = In1.data();
|
||||||
|
const Float* d_in2 = In2.data();
|
||||||
|
|
||||||
|
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
|
||||||
|
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
|
||||||
|
gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2,
|
||||||
|
Out.shape()[0], Out.shape()[1]);
|
||||||
|
cudaStreamSynchronize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Functor>
|
||||||
|
void Element(Functor functor,
|
||||||
|
Tensor Out, const Tensor In1,
|
||||||
|
const Tensor In2, const Tensor In3) {
|
||||||
|
|
||||||
|
Float* d_out = Out.data();
|
||||||
|
const Float* d_in1 = In1.data();
|
||||||
|
const Float* d_in2 = In2.data();
|
||||||
|
const Float* d_in3 = In3.data();
|
||||||
|
|
||||||
|
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
|
||||||
|
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
|
||||||
|
gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2, d_in3,
|
||||||
|
Out.shape()[0], Out.shape()[1]);
|
||||||
|
cudaStreamSynchronize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
|
||||||
|
bool transA, bool transB, Float beta) {
|
||||||
|
Float alpha = 1.0;
|
||||||
|
|
||||||
|
size_t m = A.shape()[0];
|
||||||
|
size_t k = A.shape()[1];
|
||||||
|
if(transA)
|
||||||
|
std::swap(m, k);
|
||||||
|
|
||||||
|
size_t l = B.shape()[0];
|
||||||
|
size_t n = B.shape()[1];
|
||||||
|
if(transB)
|
||||||
|
std::swap(l, n);
|
||||||
|
|
||||||
|
size_t lda = A.shape()[1];
|
||||||
|
size_t ldb = B.shape()[1];
|
||||||
|
size_t ldc = B.shape()[1];
|
||||||
|
|
||||||
|
if(transB)
|
||||||
|
ldc = B.shape()[0];
|
||||||
|
|
||||||
|
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||||
|
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||||
|
|
||||||
|
cublasSgemm(handle, opB, opA,
|
||||||
|
n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc);
|
||||||
|
return C;
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor Prod(Tensor C, const Tensor A, const Tensor B,
|
||||||
|
bool transA, bool transB, Float beta = 0) {
|
||||||
|
|
||||||
|
return Prod(handles.cublasHandle, C, A, B, transA, transB, beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
0
src/tensor_operators.h
Normal file
0
src/tensor_operators.h
Normal file
Loading…
Reference in New Issue
Block a user