From cef8a37fc76b543af862a5dc9fca35b8be9ba43d Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Sun, 28 Aug 2016 00:24:41 +0200 Subject: [PATCH] some clean-up --- CMakeLists.txt | 4 +- src/CMakeLists.txt | 2 + src/definitions.h | 26 +-- src/expression_operators.h | 51 ++--- src/expressions.cu | 53 +++++ src/expressions.h | 52 ++--- src/graph.h | 40 +++- src/graph_operators.h | 69 ++++--- src/keywords.h | 79 ++++++-- src/marian.h | 3 +- src/tensor.cu | 399 +++++++++++++++++++++++++++++++++++++ src/tensor.h | 199 +----------------- src/tensor_operators.cu | 40 ++++ src/tensor_operators.h | 151 ++++++++++++++ src/test.cu | 64 +++--- src/thrust_functions.h | 45 +++-- 16 files changed, 905 insertions(+), 372 deletions(-) create mode 100644 src/expressions.cu create mode 100644 src/tensor.cu create mode 100644 src/tensor_operators.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 42679a3d..1be00783 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.5.1) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) project(marian CXX) -SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated") -LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math;) +SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O0 -funroll-loops -Wno-unused-result -Wno-deprecated") +LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O0; -arch=sm_35; -lineinfo; --use_fast_math; -Xcompiler '-fPIC') add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM) SET(CUDA_PROPAGATE_HOST_FLAGS OFF) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3d751b51..244977db 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -8,6 +8,8 @@ add_library(libcommon OBJECT cuda_add_executable( marian test.cu + expressions.cu + tensor_operators.cu $ ) diff --git a/src/definitions.h b/src/definitions.h index 5e3fb64d..ea52024e 100644 --- a/src/definitions.h +++ b/src/definitions.h @@ -5,25 +5,25 @@ #include namespace marian { -typedef float Float; + typedef float Float; + typedef std::vector Shape; + const int whatevs{-1}; } #include "keywords.h" #include "tensor.h" namespace marian { - -typedef std::vector Shape; -const int whatevs{-1}; + class Tensor; -namespace keywords { - KEY(init, std::function) - KEY(axis, int) - KEY(name, std::string) - KEY(shape, Shape) - KEY(value, float) - KEY(lazy_shape, std::function) - KEY(lazy_value, std::function) -} + namespace keywords { + KEY(axis, int) + KEY(name, std::string) + KEY(shape, Shape) + KEY(value, float) + KEY(lazy_shape, std::function) + KEY(lazy_value, std::function) + KEY(init, std::function) + } } diff --git a/src/expression_operators.h b/src/expression_operators.h index 2d2ac18a..3c3dc031 100644 --- a/src/expression_operators.h +++ b/src/expression_operators.h @@ -7,8 +7,8 @@ namespace marian { template -inline Expr data(Args ...args) { - return Expr(new DataNode(args...)); +inline Expr input(Args ...args) { + return Expr(new InputNode(args...)); } template @@ -76,30 +76,31 @@ inline Expr dot(Expr a, Expr b) { /******************************************************/ -Expr broadcast(Shape shape, Expr a) { - if(a.val().shape() == shape) { +Expr broadcast(Shape bShape, Expr a) { + const Shape& aShape = a.node()->shape(); + if(aShape == bShape) { return a; } else { - size_t dimsA = a.val().shape().size(); - size_t dimsB = shape.size(); + size_t dimsA = aShape.size(); + size_t dimsB = bShape.size(); UTIL_THROW_IF2(dimsA != dimsB, "Tensor and shape have different number of dimensions"); for(size_t i = 0; i < dimsA; ++i) { - int dimA = a.val().shape()[i]; - int dimB = shape[i]; + int dimA = aShape[i]; + int dimB = bShape[i]; bool broadcastable = (dimA == dimB || dimA == 1); UTIL_THROW_IF2(!broadcastable, "Cannot broadcast tensor dimension " << dimA << " to " << dimB); - if(dimA == 1 && dimB > 1) { + if(dimA == 1 && dimB != 1) { std::cerr << "Broadcasting dim " << i << " from " << dimA << " to " << dimB << std::endl; if(i == 0) { - Expr one = ones(keywords::shape={shape[0], 1}); + Expr one = ones(keywords::shape={bShape[0], 1}); a = dot(one, a); } else if(i == 1) { - Expr one = ones(keywords::shape={1, shape[1]}); + Expr one = ones(keywords::shape={1, bShape[1]}); a = dot(a, one); } else { @@ -120,20 +121,23 @@ inline Expr sum(Expr a, Args ...args) { Keywords params(args...); int ax = params.Get(axis, whatevs); + ChainPtr n = a.node(); if(ax == 0) { - auto lshape = [&a]() -> Shape { - int rows = a.val().shape()[0]; + auto lshape = [n]() -> Shape { + int rows = n->val().shape()[0]; return {1, rows}; }; - Expr one = ones(lazy_shape=lshape); + Expr one = ones(shape={1, n->shape()[0]}, + lazy_shape=lshape); return dot(one, a); } else if(ax == 1) { - auto lshape = [&a]() -> Shape { - int cols = a.val().shape()[1]; + auto lshape = [n]() -> Shape { + int cols = n->val().shape()[1]; return {cols, 1}; }; - Expr one = ones(lazy_shape=lshape); + Expr one = ones(shape={n->shape()[1], 1}, + lazy_shape=lshape); return dot(a, one); } else if(ax == 2) { @@ -159,16 +163,17 @@ inline Expr mean(Expr a, Args ...args) { Keywords params(args...); size_t ax = params.Get(axis, whatevs); + ChainPtr n = a.node(); switch (ax) { case 0: return sum(a, axis=0) / constant(shape={1, 1}, - lazy_value=[&a]() -> Float { - return a.val().shape()[0]; + lazy_value=[n]() -> Float { + return n->val().shape()[0]; }); case 1: return sum(a, axis=1) / constant(shape={1, 1}, - lazy_value=[&a]() -> Float { - return a.val().shape()[1]; + lazy_value=[n]() -> Float { + return n->val().shape()[1]; }); case 2: UTIL_THROW2("Not implemented"); @@ -176,8 +181,8 @@ inline Expr mean(Expr a, Args ...args) { UTIL_THROW2("Not implemented"); default: return sum(a) / constant(shape={1, 1}, - lazy_value=[&a]() -> Float { - return a.val().size(); + lazy_value=[n]() -> Float { + return n->val().size(); }); } } diff --git a/src/expressions.cu b/src/expressions.cu new file mode 100644 index 00000000..5065f697 --- /dev/null +++ b/src/expressions.cu @@ -0,0 +1,53 @@ +#include "expressions.h" +#include "graph_operators.h" + +namespace marian { + +Expr::Expr(Chainable* chainable) : pimpl_(chainable) {} +Expr::Expr(Float v) : pimpl_(new ConstantNode(keywords::value=v, + keywords::shape={1,1})) {} + +Tensor Expr::val() { + return pimpl_->val(); +} + +Tensor Expr::grad() { + return pimpl_->grad(); +} + +ChainPtr Expr::node() { + return pimpl_; +} + +void Expr::forward(size_t batchSize) { + UTIL_THROW_IF2(pimpl_.get() != Chainable::stack.back(), + "Trying to call forward on non-root of computation graph"); + + std::cerr << "a" << std::endl; + for(auto&& v : Chainable::stack) { + v->allocate(batchSize); + } + + std::cerr << "f" << std::endl; + for(auto&& v : Chainable::stack) + v->forward(); +} + +void Expr::backward() { + UTIL_THROW_IF2(pimpl_.get() != Chainable::stack.back(), + "Trying to call backward on non-root of computation graph"); + + for(auto&& v : Chainable::stack) + v->set_zero_adjoint(); + + typedef typename Chainable::ChainableStack::reverse_iterator It; + pimpl_->init_dependent(); + for(It it = Chainable::stack.rbegin(); it != Chainable::stack.rend(); ++it) + (*it)->backward(); +} + +Expr::operator ChainPtr() { + return pimpl_; +} + +} \ No newline at end of file diff --git a/src/expressions.h b/src/expressions.h index b78acf79..90445603 100644 --- a/src/expressions.h +++ b/src/expressions.h @@ -1,52 +1,28 @@ #pragma once +#include "definitions.h" +#include "graph.h" + namespace marian { class Expr { public: - Expr(Chainable* chainable) : pimpl_(chainable) {} + Expr(Chainable* chainable); + Expr(Float v); - Tensor val() { - return pimpl_->val(); + Expr operator=(Tensor t) { + pimpl_->setVal(t); + return *this; } - Tensor grad() { - return pimpl_->grad(); - } + Tensor val(); + Tensor grad(); - ChainPtr pimpl() { - return pimpl_; - } + void forward(size_t batchSize); + void backward(); - void forward() { - UTIL_THROW_IF2(pimpl_.get() != stack.back(), - "Trying to call forward on non-root of computation graph"); - - std::cerr << "a" << std::endl; - for(auto&& v : stack) - v->allocate(); - - std::cerr << "f" << std::endl; - for(auto&& v : stack) - v->forward(); - } - - void backward() { - UTIL_THROW_IF2(pimpl_.get() != stack.back(), - "Trying to call backward on non-root of computation graph"); - - for(auto&& v : stack) - v->set_zero_adjoint(); - - typedef ChainableStack::reverse_iterator It; - pimpl_->init_dependent(); - for(It it = stack.rbegin(); it != stack.rend(); ++it) - (*it)->backward(); - } - - operator ChainPtr() { - return pimpl_; - } + ChainPtr node(); + operator ChainPtr(); private: ChainPtr pimpl_; diff --git a/src/graph.h b/src/graph.h index 47313501..15b4721d 100644 --- a/src/graph.h +++ b/src/graph.h @@ -14,16 +14,23 @@ struct Chainable { virtual void init_dependent() { } virtual void set_zero_adjoint() { } - virtual void allocate() = 0; + virtual void allocate(size_t) = 0; + virtual const Shape& shape() = 0; virtual DataType val() = 0; virtual DataType grad() = 0; + virtual void setVal(Tensor t) { + UTIL_THROW2("Tensors can only be assigned to input nodes"); + }; + + typedef std::vector*> ChainableStack; + static ChainableStack stack; }; -typedef std::vector*> ChainableStack; -typedef std::shared_ptr> ChainPtr; +template +typename Chainable::ChainableStack Chainable::stack; -ChainableStack stack; +typedef std::shared_ptr> ChainPtr; class Node : public Chainable, public keywords::Keywords { @@ -34,14 +41,27 @@ class Node : public Chainable, shape_(Get(keywords::shape, {1, 1})), name_(Get(keywords::name, "none")) { - std::cerr << "Creating node " << name_ << std::endl; stack.push_back(this); } virtual ~Node() {}; - virtual void allocate() { - val_.allocate(shape_); + virtual void allocate(size_t batchSize) { + for(auto&& d : shape_) { + if(d == whatevs) + d = batchSize; + } + if(Has(keywords::lazy_shape)) { + auto defaultShape = [this]() -> Shape { return shape_; }; + shape_ = Get>(keywords::lazy_shape, defaultShape)(); + } + if(Has(keywords::lazy_value)) + val_.allocate(shape_, Get>( + keywords::lazy_value, []()->Float{return 0.f;})()); + else if(Has(keywords::value)) + val_.allocate(shape_, Get(keywords::value, 0)); + else + val_.allocate(shape_); } virtual void init_dependent() { @@ -71,7 +91,11 @@ class Node : public Chainable, UTIL_THROW_IF2(!adj_, "Tensor has not been allocated"); return adj_; }; - + + virtual const Shape& shape() { + return shape_; + } + protected: Shape shape_; std::string name_; diff --git a/src/graph_operators.h b/src/graph_operators.h index e22dcf4f..d07c4b38 100644 --- a/src/graph_operators.h +++ b/src/graph_operators.h @@ -1,16 +1,26 @@ #pragma once -#include "graph.h" #include "expressions.h" -//#include "expression_operators.h" +#include "graph.h" +#include "tensor_operators.h" namespace marian { -struct DataNode : public Node { +struct InputNode : public Node { template - DataNode(Args ...args) - : Node(args...) { } + InputNode(Args ...args) + : Node(args...) { + UTIL_THROW_IF2(!Has(keywords::shape) && + !Has(keywords::lazy_shape), + "Data items require shape information"); + } + virtual void setVal(Tensor t) { + val_ = t; + shape_ = t.shape(); + //@todo, shape checking + }; + void forward() {} void backward() {} }; @@ -18,7 +28,11 @@ struct DataNode : public Node { struct ConstantNode : public Node { template ConstantNode(Args ...args) - : Node(args...) { } + : Node(args...) { + UTIL_THROW_IF2(!Has(keywords::shape) && + !Has(keywords::lazy_shape), + "Constant items require shape information"); + } void forward() {} void backward() {} @@ -29,12 +43,16 @@ struct ParamNode : public Node { ParamNode(Args ...args) : Node(args...), init_(Get>(keywords::init, [](Tensor){ })) - { } + { + UTIL_THROW_IF2(!Has(keywords::shape) && + !Has(keywords::lazy_shape), + "Param items require shape information"); + } void forward() {} void backward() {} - virtual void allocate() { + virtual void allocate(size_t batchSize) { val_.allocate(shape_); init_(val_); } @@ -86,9 +104,7 @@ struct TanhNodeOp : public UnaryNodeOp { struct LogNodeOp : public UnaryNodeOp { template LogNodeOp(Args ...args) - : UnaryNodeOp(args...) { - std::cerr << "log" << std::endl; - } + : UnaryNodeOp(args...) {} void forward() { Element(_1 = Log(_2), val_, a_->val()); @@ -145,13 +161,15 @@ struct BinaryNodeOp : public Node { struct DotNodeOp : public BinaryNodeOp { template DotNodeOp(ChainPtr a, ChainPtr b, Args ...args) - : BinaryNodeOp(a, b, args...) { } + : BinaryNodeOp(a, b, + keywords::shape=newShape(a,b), + args...) { } - Shape shape(ChainPtr a, ChainPtr b) { - UTIL_THROW_IF2(a->val().shape()[1] != b->val().shape()[0], + Shape newShape(ChainPtr a, ChainPtr b) { + Shape shape1 = a->shape(); + Shape shape2 = b->shape(); + UTIL_THROW_IF2(shape1[1] != shape2[0], "matrix product requires dimensions to match"); - Shape shape1 = a->val().shape(); - Shape shape2 = b->val().shape(); shape1[1] = shape2[1]; return shape1; } @@ -177,23 +195,26 @@ Expr broadcast(Shape shape, Expr a); struct BroadcastingNodeOp : public BinaryNodeOp { template BroadcastingNodeOp(Expr a, Expr b, Args ...args) - : BinaryNodeOp(broadcast(shape(a ,b), a), - broadcast(shape(a ,b), b), - args...) {} + : BinaryNodeOp(broadcast(newShape(a ,b), a), + broadcast(newShape(a ,b), b), + keywords::shape=newShape(a, b), + args...) {} - static Shape shape(ChainPtr a, ChainPtr b) { - size_t dimsA = a->val().shape().size(); - size_t dimsB = b->val().shape().size(); + static Shape newShape(ChainPtr a, ChainPtr b) { + size_t dimsA = a->shape().size(); + size_t dimsB = b->shape().size(); UTIL_THROW_IF2(dimsA != dimsB, "Tensors have different numbers of dimensions"); Shape shape(dimsA); for(size_t i = 0; i < dimsA; ++i) { - int dimA = a->val().shape()[i]; - int dimB = b->val().shape()[i]; + int dimA = a->shape()[i]; + int dimB = b->shape()[i]; bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1); UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise " << "operation cannot be broadcasted: " << dimA << " != " << dimB); shape[i] = std::max(dimA, dimB); + if(dimA == whatevs || dimB == whatevs) + shape[i] = whatevs; } return shape; } diff --git a/src/keywords.h b/src/keywords.h index 522af5d2..e72cdb9a 100644 --- a/src/keywords.h +++ b/src/keywords.h @@ -15,24 +15,23 @@ namespace keywords { public: typedef Value value_type; - struct pair { - Keyword first; - Value second; - }; + Keyword(const std::string& name, Value value) + : name_(name), value_(value) {} Keyword(const std::string& name) - : name_(name) {} + : name_(name), value_() {} - pair operator=(Value value) { - return pair{*this, value}; + Keyword operator=(Value value) const { + return Keyword(name_, value); } - const std::string& operator()() const { - return name_; + const Value& operator()() const { + return value_; } private: - std::string name_; + const std::string name_; + const Value value_; }; struct Keywords { @@ -45,12 +44,12 @@ namespace keywords { template void add(Head head) { - map_[std::type_index(typeid(head.first))] = head.second; + map_[std::type_index(typeid(head))] = head(); } template void add(Head head, Tail ...tail) { - map_[std::type_index(typeid(head.first))] = head.second; + map_[std::type_index(typeid(head))] = head(); add(tail...); } @@ -63,12 +62,66 @@ namespace keywords { return default_value; } + template + bool Has(Key key) { + auto it = map_.find(std::type_index(typeid(key))); + return it != map_.end(); + } + private: std::unordered_map map_; }; + #include + +//template +//struct is_one_of { +// static constexpr bool value = false; +//}; +// +//template +//struct is_one_of { +// static constexpr bool value = +// std::is_same::value || is_one_of::value; +//}; +// +//template +//struct Index; +// +//template +//struct Index> { +// static constexpr std::size_t value = 0; +//}; +// +//template +//struct Index> { +// static constexpr std::size_t value = 1 + Index>::value; +//}; +// +//struct True {}; +//struct False {}; +// +//template +//typename Match::value_type opt(True foo, Args... args) { +// std::tuple t(args...); +// return std::get>::value>(t)(); +//} +// +//template +//typename Match::value_type opt(False foo, Args... args) { +// return typename Match::value_type(); +//} +// +//template +//typename Match::value_type Get(Args ...args) { +// constexpr bool match = is_one_of::value; +// typename std::conditional::type condition; +// return opt(condition, args...); +//} + + #define KEY(name, value_type) \ - typedef Keyword name ## _k; \ + typedef const Keyword name ## _k; \ name ## _k name(#name); } diff --git a/src/marian.h b/src/marian.h index 18510943..8c987ccf 100644 --- a/src/marian.h +++ b/src/marian.h @@ -5,5 +5,4 @@ #include "graph_operators.h" #include "expressions.h" #include "expression_operators.h" -//#include "tensor.h" -//#include "tensor_operators.h" + diff --git a/src/tensor.cu b/src/tensor.cu new file mode 100644 index 00000000..bf56ce27 --- /dev/null +++ b/src/tensor.cu @@ -0,0 +1,399 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "definitions.h" +#include "exception.h" +#include "thrust_functions.h" + +namespace marian { + +struct Handles { + cudnnHandle_t cudnnHandle; + cublasHandle_t cublasHandle; + + cudnnOpTensorDescriptor_t add; + + Handles() { + cudnnCreate(&cudnnHandle); + cublasCreate(&cublasHandle); + cudnnCreateOpTensorDescriptor(&add); + cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN); + } + + ~Handles() { + cudnnDestroy(cudnnHandle); + cublasDestroy(cublasHandle); + cudnnDestroyOpTensorDescriptor(add); + } +}; + +Handles handles; + +typedef std::vector Shape; + +template +class TensorImpl { + private: + Shape shape_; + thrust::device_vector data_; + cudnnTensorDescriptor_t desc_; + size_t tno_; + static size_t tensorCounter; + + cudnnDataType_t dataType() { + switch(sizeof(Float)) { + case 2: return CUDNN_DATA_HALF; + case 8: return CUDNN_DATA_DOUBLE; + default: return CUDNN_DATA_FLOAT; + } + } + + public: + typedef Float value_type; + + TensorImpl(const Shape& shape, value_type value = 0) + : shape_(shape), tno_(tensorCounter++) + { + + // @TODO: + UTIL_THROW_IF2(shape_.size() != 2, + "For now, only 2D Tensors, will be fixed later."); + + UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4, + "Wrong number of dimensions: " << shape_.size()); + + std::cerr << "Allocating : " << shape[0] << " " << shape[1] << std::endl; + + int size = std::accumulate(shape_.begin(), shape_.end(), + 1, std::multiplies()); + data_.resize(size, value); + cudnnCreateTensorDescriptor(&desc_); + switch (shape_.size()) { + case 1: + cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + shape_[0], 1, 1, 1); break; + case 2: + cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + shape_[0], shape_[1], 1, 1); break; + case 3: + cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + shape_[0], shape_[1], shape_[2], 1); break; + case 4: + cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + shape_[0], shape_[1], shape_[2], shape_[3]); break; + } + } + + TensorImpl(const TensorImpl&) = delete; + TensorImpl(TensorImpl&&) = delete; + + ~TensorImpl() { + cudnnDestroyTensorDescriptor(desc_); + } + + value_type operator[](size_t i) const { + return data_[i]; + } + + auto begin() -> decltype( data_.begin() ) { + return data_.begin(); + } + + auto begin() const -> decltype( data_.begin() ) { + return data_.begin(); + } + + auto end() -> decltype( data_.end() ) { + return data_.end(); + } + + auto end() const -> decltype( data_.end() ) { + return data_.end(); + } + + const Shape& shape() const { + return shape_; + } + + size_t size() const { + return data_.size(); + } + + value_type* data() { + return thrust::raw_pointer_cast(data_.data()); + } + + cudnnTensorDescriptor_t desc() const { + return desc_; + } + + size_t id() const { + return tno_; + } + + void set(value_type value) { + thrust::fill(data_.begin(), data_.end(), value); + } +}; + +template +size_t TensorImpl::tensorCounter = 0; + +class Tensor { + private: + std::shared_ptr> pimpl_; + + public: + typedef TensorImpl::value_type value_type; + + Tensor() {} + ~Tensor() {} + + void allocate(Shape shape, value_type value = 0) { + pimpl_.reset(new TensorImpl(shape, value)); + } + + value_type operator[](size_t i) const { + return (*pimpl_)[i]; + } + + size_t size() const { + return pimpl_->size(); + } + + value_type* data() { + return pimpl_->data(); + } + + const value_type* data() const { + return pimpl_->data(); + } + + auto begin() -> decltype( pimpl_->begin() ) { + return pimpl_->begin(); + } + + auto begin() const -> decltype( pimpl_->begin() ) { + return pimpl_->begin(); + } + + auto end() -> decltype( pimpl_->begin() ) { + return pimpl_->begin(); + } + + auto end() const -> decltype( pimpl_->begin() ) { + return pimpl_->begin(); + } + + const Shape& shape() const { + return pimpl_->shape(); + } + + cudnnTensorDescriptor_t desc() const { + return pimpl_->desc(); + } + + void set(value_type value) { + pimpl_->set(value); + } + + size_t id() const { + return pimpl_->id(); + } + + operator bool() { + return pimpl_ != nullptr; + } +}; + +Tensor uniform(Tensor t, Float a=-0.1, Float b=0.1) { + std::vector r(t.size()); + for(int i = 0; i < r.size(); i++) + r[i] = (Float(rand() % 2000) - 1000.0)/10000.0; + thrust::copy(r.begin(), r.end(), t.begin()); + return t; +}; + +using namespace thrust::placeholders; +#define MAX_THREADS 512 +#define MAX_BLOCKS 65535 + +template +__global__ void gElement(Functor functor, Float* out, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + Float* rowOut = out + j * cols; + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i]);; + } + } + } +} + +template +__global__ void gElement(Functor functor, + Float* out, const Float* in, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + Float* rowOut = out + j * cols; + const Float* rowIn = in + j * cols; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i], rowIn[i]);; + } + } + } +} + +template +__global__ void gElement(Functor functor, + Float* out, const Float* in1, const Float* in2, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + Float* rowOut = out + j * cols; + const Float* rowIn1 = in1 + j * cols; + const Float* rowIn2 = in2 + j * cols; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i]); + } + } + } +} + +template +__global__ void gElement(Functor functor, + Float* out, const Float* in1, + const Float* in2, const Float* in3, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + Float* rowOut = out + j * cols; + const Float* rowIn1 = in1 + j * cols; + const Float* rowIn2 = in2 + j * cols; + const Float* rowIn3 = in3 + j * cols; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i], rowIn3[i]); + } + } + } +} + +// @TODO add broadcasting + +template +void Element(Functor functor, Tensor Out) { + Float* d_out = Out.data(); + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +template +void Element(Functor functor, + Tensor Out, const Tensor In) { + Float* d_out = Out.data(); + const Float* d_in = In.data(); + + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, d_in, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +template +void Element(Functor functor, + Tensor Out, const Tensor In1, const Tensor In2) { + + Float* d_out = Out.data(); + const Float* d_in1 = In1.data(); + const Float* d_in2 = In2.data(); + + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, d_in1, d_in2, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +template +void Element(Functor functor, + Tensor Out, const Tensor In1, + const Tensor In2, const Tensor In3) { + + Float* d_out = Out.data(); + const Float* d_in1 = In1.data(); + const Float* d_in2 = In2.data(); + const Float* d_in3 = In3.data(); + + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, d_in1, d_in2, d_in3, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B, + bool transA, bool transB, Float beta) { + Float alpha = 1.0; + + size_t m = A.shape()[0]; + size_t k = A.shape()[1]; + if(transA) + std::swap(m, k); + + size_t l = B.shape()[0]; + size_t n = B.shape()[1]; + if(transB) + std::swap(l, n); + + size_t lda = A.shape()[1]; + size_t ldb = B.shape()[1]; + size_t ldc = B.shape()[1]; + + if(transB) + ldc = B.shape()[0]; + + cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; + + cublasSgemm(handle, opB, opA, + n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc); + return C; +} + +Tensor Prod(Tensor C, const Tensor A, const Tensor B, + bool transA, bool transB, Float beta = 0) { + + return Prod(handles.cublasHandle, C, A, B, transA, transB, beta); +} + +} \ No newline at end of file diff --git a/src/tensor.h b/src/tensor.h index 99646ceb..4b1e186f 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -1,10 +1,5 @@ #pragma once -#include -#include -#include -#include - #include #include #include @@ -36,7 +31,7 @@ struct Handles { } }; -Handles handles; +const Handles handles; typedef std::vector Shape; @@ -63,12 +58,16 @@ class TensorImpl { TensorImpl(const Shape& shape, value_type value = 0) : shape_(shape), tno_(tensorCounter++) { + // @TODO: UTIL_THROW_IF2(shape_.size() != 2, "For now, only 2D Tensors, will be fixed later."); UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4, "Wrong number of dimensions: " << shape_.size()); + + std::cerr << "Allocating : " << shape[0] << " " << shape[1] << std::endl; + int size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies()); data_.resize(size, value); @@ -152,10 +151,15 @@ class Tensor { typedef TensorImpl::value_type value_type; Tensor() {} + Tensor(Shape shape, value_type value = 0) { + allocate(shape, value); + } + ~Tensor() {} void allocate(Shape shape, value_type value = 0) { - pimpl_.reset(new TensorImpl(shape, value)); + if(!pimpl_) + pimpl_.reset(new TensorImpl(shape, value)); } value_type operator[](size_t i) const { @@ -211,185 +215,4 @@ class Tensor { } }; -Tensor uniform(Tensor t, Float a=-0.1, Float b=0.1) { - std::vector r(t.size()); - for(int i = 0; i < r.size(); i++) - r[i] = (Float(rand() % 2000) - 1000.0)/10000.0; - thrust::copy(r.begin(), r.end(), t.begin()); - return t; -}; - -using namespace thrust::placeholders; -#define MAX_THREADS 512 -#define MAX_BLOCKS 65535 - -template -__global__ void gElement(Functor functor, Float* out, - size_t rows, size_t cols) { - for(int bid = 0; bid < rows; bid += gridDim.x) { - int j = bid + blockIdx.x; - if(j < rows) { - Float* rowOut = out + j * cols; - for(int tid = 0; tid < cols; tid += blockDim.x) { - int i = tid + threadIdx.x; - if(i < cols) - rowOut[i] = functor(rowOut[i]);; - } - } - } -} - -template -__global__ void gElement(Functor functor, - Float* out, const Float* in, - size_t rows, size_t cols) { - for(int bid = 0; bid < rows; bid += gridDim.x) { - int j = bid + blockIdx.x; - if(j < rows) { - Float* rowOut = out + j * cols; - const Float* rowIn = in + j * cols; - - for(int tid = 0; tid < cols; tid += blockDim.x) { - int i = tid + threadIdx.x; - if(i < cols) - rowOut[i] = functor(rowOut[i], rowIn[i]);; - } - } - } -} - -template -__global__ void gElement(Functor functor, - Float* out, const Float* in1, const Float* in2, - size_t rows, size_t cols) { - for(int bid = 0; bid < rows; bid += gridDim.x) { - int j = bid + blockIdx.x; - if(j < rows) { - Float* rowOut = out + j * cols; - const Float* rowIn1 = in1 + j * cols; - const Float* rowIn2 = in2 + j * cols; - - for(int tid = 0; tid < cols; tid += blockDim.x) { - int i = tid + threadIdx.x; - if(i < cols) - rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i]); - } - } - } -} - -template -__global__ void gElement(Functor functor, - Float* out, const Float* in1, - const Float* in2, const Float* in3, - size_t rows, size_t cols) { - for(int bid = 0; bid < rows; bid += gridDim.x) { - int j = bid + blockIdx.x; - if(j < rows) { - Float* rowOut = out + j * cols; - const Float* rowIn1 = in1 + j * cols; - const Float* rowIn2 = in2 + j * cols; - const Float* rowIn3 = in3 + j * cols; - - for(int tid = 0; tid < cols; tid += blockDim.x) { - int i = tid + threadIdx.x; - if(i < cols) - rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i], rowIn3[i]); - } - } - } -} - -// @TODO add broadcasting - -template -void Element(Functor functor, Tensor Out) { - Float* d_out = Out.data(); - int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); - int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); - gElement<<>>(functor, d_out, - Out.shape()[0], Out.shape()[1]); - cudaStreamSynchronize(0); -} - -template -void Element(Functor functor, - Tensor Out, const Tensor In) { - Float* d_out = Out.data(); - const Float* d_in = In.data(); - - int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); - int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); - gElement<<>>(functor, d_out, d_in, - Out.shape()[0], Out.shape()[1]); - cudaStreamSynchronize(0); -} - -template -void Element(Functor functor, - Tensor Out, const Tensor In1, const Tensor In2) { - - Float* d_out = Out.data(); - const Float* d_in1 = In1.data(); - const Float* d_in2 = In2.data(); - - int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); - int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); - gElement<<>>(functor, d_out, d_in1, d_in2, - Out.shape()[0], Out.shape()[1]); - cudaStreamSynchronize(0); -} - -template -void Element(Functor functor, - Tensor Out, const Tensor In1, - const Tensor In2, const Tensor In3) { - - Float* d_out = Out.data(); - const Float* d_in1 = In1.data(); - const Float* d_in2 = In2.data(); - const Float* d_in3 = In3.data(); - - int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); - int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); - gElement<<>>(functor, d_out, d_in1, d_in2, d_in3, - Out.shape()[0], Out.shape()[1]); - cudaStreamSynchronize(0); -} - -Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B, - bool transA, bool transB, Float beta) { - Float alpha = 1.0; - - size_t m = A.shape()[0]; - size_t k = A.shape()[1]; - if(transA) - std::swap(m, k); - - size_t l = B.shape()[0]; - size_t n = B.shape()[1]; - if(transB) - std::swap(l, n); - - size_t lda = A.shape()[1]; - size_t ldb = B.shape()[1]; - size_t ldc = B.shape()[1]; - - if(transB) - ldc = B.shape()[0]; - - cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - - cublasSgemm(handle, opB, opA, - n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc); - return C; -} - -Tensor Prod(Tensor C, const Tensor A, const Tensor B, - bool transA, bool transB, Float beta = 0) { - - return Prod(handles.cublasHandle, C, A, B, transA, transB, beta); -} - } \ No newline at end of file diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu new file mode 100644 index 00000000..a8f72893 --- /dev/null +++ b/src/tensor_operators.cu @@ -0,0 +1,40 @@ +#include "tensor_operators.h" + +namespace marian { + +Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B, + bool transA, bool transB, Float beta) { + Float alpha = 1.0; + + size_t m = A.shape()[0]; + size_t k = A.shape()[1]; + if(transA) + std::swap(m, k); + + size_t l = B.shape()[0]; + size_t n = B.shape()[1]; + if(transB) + std::swap(l, n); + + size_t lda = A.shape()[1]; + size_t ldb = B.shape()[1]; + size_t ldc = B.shape()[1]; + + if(transB) + ldc = B.shape()[0]; + + cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; + + cublasSgemm(handle, opB, opA, + n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc); + return C; +} + +Tensor Prod(Tensor C, const Tensor A, const Tensor B, + bool transA, bool transB, Float beta) { + + return Prod(handles.cublasHandle, C, A, B, transA, transB, beta); +} + +} \ No newline at end of file diff --git a/src/tensor_operators.h b/src/tensor_operators.h index e69de29b..7ec4ca68 100644 --- a/src/tensor_operators.h +++ b/src/tensor_operators.h @@ -0,0 +1,151 @@ +#pragma once + +#include "tensor.h" + +namespace marian { + +using namespace thrust::placeholders; +#define MAX_THREADS 512 +#define MAX_BLOCKS 65535 + +template +__global__ void gElement(Functor functor, Float* out, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + Float* rowOut = out + j * cols; + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i]);; + } + } + } +} + +template +__global__ void gElement(Functor functor, + Float* out, const Float* in, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + Float* rowOut = out + j * cols; + const Float* rowIn = in + j * cols; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i], rowIn[i]);; + } + } + } +} + +template +__global__ void gElement(Functor functor, + Float* out, const Float* in1, const Float* in2, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + Float* rowOut = out + j * cols; + const Float* rowIn1 = in1 + j * cols; + const Float* rowIn2 = in2 + j * cols; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i]); + } + } + } +} + +template +__global__ void gElement(Functor functor, + Float* out, const Float* in1, + const Float* in2, const Float* in3, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + Float* rowOut = out + j * cols; + const Float* rowIn1 = in1 + j * cols; + const Float* rowIn2 = in2 + j * cols; + const Float* rowIn3 = in3 + j * cols; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i], rowIn3[i]); + } + } + } +} + +// @TODO add broadcasting + +template +void Element(Functor functor, Tensor Out) { + Float* d_out = Out.data(); + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +template +void Element(Functor functor, + Tensor Out, const Tensor In) { + Float* d_out = Out.data(); + const Float* d_in = In.data(); + + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, d_in, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +template +void Element(Functor functor, + Tensor Out, const Tensor In1, const Tensor In2) { + + Float* d_out = Out.data(); + const Float* d_in1 = In1.data(); + const Float* d_in2 = In2.data(); + + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, d_in1, d_in2, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +template +void Element(Functor functor, + Tensor Out, const Tensor In1, + const Tensor In2, const Tensor In3) { + + Float* d_out = Out.data(); + const Float* d_in1 = In1.data(); + const Float* d_in2 = In2.data(); + const Float* d_in3 = In3.data(); + + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, d_in1, d_in2, d_in3, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B, + bool transA, bool transB, Float beta); + +Tensor Prod(Tensor C, const Tensor A, const Tensor B, + bool transA, bool transB, Float beta = 0); + +} diff --git a/src/test.cu b/src/test.cu index d35150b5..56156fee 100644 --- a/src/test.cu +++ b/src/test.cu @@ -6,50 +6,34 @@ int main(int argc, char** argv) { using namespace marian; using namespace keywords; - auto x = data(shape={whatevs, 784}, name="X"); - auto y = data(shape={whatevs, 10}, name="Y"); - + auto x = input(shape={whatevs, 784}, name="X"); + auto y = input(shape={whatevs, 10}, name="Y"); + auto w = param(shape={784, 10}, name="W0"); auto b = param(shape={1, 10}, name="b0"); - auto lr = softmax(dot(x, w) + b, axis=1); - auto cost = -mean(sum(y * log(lr), axis=1), axis=0); - - cost.forward(); + auto lr = softmax(dot(x, w) + b, axis=1, name="pred"); + auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost"); - //auto set = [](size_t i, Expr c) { - // size_t bid = (i + 1) % batches; - // Tensor x = c["X"].val(); - // thrust::copy(XBatches[bid].begin(), XBatches[bid].end(), - // x.begin()); - // Tensor y = c["Y"].val(); - // thrust::copy(YBatches[bid].begin(), YBatches[bid].end(), - // y.begin()); - //}; - // - //auto before = [](size_t i, Expr c) { - // for(auto&& p : c.params()) - // clip(p.grad(), type=norm, max=10); - //}; - // - // - //float sum; - //auto after = [&sum](size_t i, Expr c) { - // sum += c.val()[0]; - // - // if(i % 100 == 0) { - // std::cerr << sum / i << std::endl; - // std::cerr << i << " : " << c.val()[0] << std::endl; - // } - // - // if(i % 10000 == 0) { - // std::cerr << "Saving model " << i << std::endl; - // std::stringstream name; - // name << "model.iter" << i << ".yml.gz"; - // dump(c, name.str()); - // } - // - //}; + Tensor tx({500, 784}, 1); + Tensor ty({500, 10}, 1); + + x = tx; + y = ty; + + graph.forward(500); + //std::cerr << graph["pred"].val()[0] << std::endl; + + + //hook0(graph); + //graph.autodiff(); + //std::cerr << graph["cost"].val()[0] << std::endl; + //hook1(graph); + //for(auto p : graph.params()) { + // auto update = _1 = _1 - alpha * _2; + // Element(update, p.val(), p.grad()); + //} + //hook2(graph); // //auto opt = adadelta(cost_function=cost, // eta=0.9, gamma=0.1, diff --git a/src/thrust_functions.h b/src/thrust_functions.h index a3013423..2712fda7 100644 --- a/src/thrust_functions.h +++ b/src/thrust_functions.h @@ -11,29 +11,19 @@ namespace thrust { namespace functional { - - // Ugly hacks, but it seems this is neccessary. - __host__ __device__ - float expf2(float x) { - float clip = 16; - if(x > clip) - x = clip; - if(x < -clip) - x = -clip; - return expf(x); - } - __host__ __device__ - float logf2(float x) { - if(x < 10e-10) - x = 10e-10; - return logf(x); - } - template struct unary_exp : public thrust::unary_function { __host__ __device__ - T operator()(const T &x) const { return expf2(x); } + T operator()(const T &x) const { + float x2 = x; + float clip = 16; + if(x2 > clip) + x2 = clip; + if(x2 < -clip) + x2 = -clip; + return expf(x2); + } }; template @@ -46,7 +36,12 @@ namespace thrust template struct unary_log : public thrust::unary_function { __host__ __device__ - T operator()(const T &x) const { return logf2(x); } + T operator()(const T &x) const { + float x2 = x; + if(x2 < 10e-10) + x2 = 10e-10; + return logf(x2); + } }; template @@ -59,7 +54,15 @@ namespace thrust template struct unary_sigma : public thrust::unary_function { __host__ __device__ - T operator()(const T &x) const { return 1.0 / (1.0 + expf2(-x)); } + T operator()(const T &x) const { + float x2 = x; + float clip = 16; + if(x2 > clip) + x2 = clip; + if(x2 < -clip) + x2 = -clip; + return 1.0 / (1.0 + expf(-x2)); + } }; template