diff --git a/src/expressions.cu b/src/expressions.cu
index 5065f697..a460a39e 100644
--- a/src/expressions.cu
+++ b/src/expressions.cu
@@ -1,6 +1,9 @@
+#include <sstream>
 #include "expressions.h"
 #include "graph_operators.h"
 
+using namespace std;
+
 namespace marian {
 
 Expr::Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
@@ -49,5 +52,14 @@ void Expr::backward() {
 Expr::operator ChainPtr() {
   return pimpl_;
 }
+
+std::string Expr::Debug() const
+{
+  stringstream strm;
+  //const Chainable<Tensor> &ct = *pimpl_;
+  const Shape &shape = pimpl_->shape();
+  strm << marian::Debug(shape);
+  return strm.str();
+}
 
-}
\ No newline at end of file
+}
diff --git a/src/expressions.h b/src/expressions.h
index 90445603..d7945f07 100644
--- a/src/expressions.h
+++ b/src/expressions.h
@@ -24,8 +24,10 @@ class Expr {
     ChainPtr node();
     operator ChainPtr();
 
+    std::string Debug() const;
+
   private:
     ChainPtr pimpl_;
 };
 
-}
\ No newline at end of file
+}
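
Aside (not part of the patch): Expr::Debug() delegates the shape formatting to the marian::Debug(Shape) helper that the tensor.h hunk below introduces. A self-contained sketch of that formatting, runnable on its own; the typedef matches tensor.h:

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    typedef std::vector<int> Shape;   // same typedef as tensor.h

    // Mirrors the inline Debug(const Shape&) helper added below:
    // first dimension bare, remaining dimensions joined with 'x'.
    std::string Debug(const Shape& shape) {
      std::stringstream strm;
      strm << shape[0];
      for (size_t i = 1; i < shape.size(); ++i)
        strm << "x" << shape[i];
      return strm.str();
    }

    int main() {
      std::cerr << Debug({784, 10}) << std::endl;   // prints "784x10"
      return 0;
    }
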
diff --git a/src/tensor.cu b/src/tensor.cu
deleted file mode 100644
index bf56ce27..00000000
--- a/src/tensor.cu
+++ /dev/null
@@ -1,399 +0,0 @@
-#pragma once
-
-#include <memory>
-#include <iostream>
-#include <sstream>
-#include <numeric>
-
-#include <cudnn.h>
-#include <cublas_v2.h>
-#include <thrust/device_vector.h>
-#include <thrust/functional.h>
-
-#include "definitions.h"
-#include "exception.h"
-#include "thrust_functions.h"
-
-namespace marian {
-
-struct Handles {
-  cudnnHandle_t cudnnHandle;
-  cublasHandle_t cublasHandle;
-
-  cudnnOpTensorDescriptor_t add;
-
-  Handles() {
-    cudnnCreate(&cudnnHandle);
-    cublasCreate(&cublasHandle);
-    cudnnCreateOpTensorDescriptor(&add);
-    cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN);
-  }
-
-  ~Handles() {
-    cudnnDestroy(cudnnHandle);
-    cublasDestroy(cublasHandle);
-    cudnnDestroyOpTensorDescriptor(add);
-  }
-};
-
-Handles handles;
-
-typedef std::vector<int> Shape;
-
-template <class Float>
-class TensorImpl {
-  private:
-    Shape shape_;
-    thrust::device_vector<Float> data_;
-    cudnnTensorDescriptor_t desc_;
-    size_t tno_;
-    static size_t tensorCounter;
-
-    cudnnDataType_t dataType() {
-      switch(sizeof(Float)) {
-        case 2: return CUDNN_DATA_HALF;
-        case 8: return CUDNN_DATA_DOUBLE;
-        default: return CUDNN_DATA_FLOAT;
-      }
-    }
-
-  public:
-    typedef Float value_type;
-
-    TensorImpl(const Shape& shape, value_type value = 0)
-    : shape_(shape), tno_(tensorCounter++)
-    {
-
-      // @TODO:
-      UTIL_THROW_IF2(shape_.size() != 2,
-                     "For now, only 2D Tensors, will be fixed later.");
-
-      UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4,
-                     "Wrong number of dimensions: " << shape_.size());
-
-      std::cerr << "Allocating : " << shape[0] << " " << shape[1] << std::endl;
-
-      int size = std::accumulate(shape_.begin(), shape_.end(),
-                                 1, std::multiplies<int>());
-      data_.resize(size, value);
-      cudnnCreateTensorDescriptor(&desc_);
-      switch (shape_.size()) {
-        case 1:
-          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
-                                     shape_[0], 1, 1, 1); break;
-        case 2:
-          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
-                                     shape_[0], shape_[1], 1, 1); break;
-        case 3:
-          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
-                                     shape_[0], shape_[1], shape_[2], 1); break;
-        case 4:
-          cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
-                                     shape_[0], shape_[1], shape_[2], shape_[3]); break;
-      }
-    }
-
-    TensorImpl(const TensorImpl&) = delete;
-    TensorImpl(TensorImpl&&) = delete;
-
-    ~TensorImpl() {
-      cudnnDestroyTensorDescriptor(desc_);
-    }
-
-    value_type operator[](size_t i) const {
-      return data_[i];
-    }
-
-    auto begin() -> decltype( data_.begin() ) {
-      return data_.begin();
-    }
-
-    auto begin() const -> decltype( data_.begin() ) {
-      return data_.begin();
-    }
-
-    auto end() -> decltype( data_.end() ) {
-      return data_.end();
-    }
-
-    auto end() const -> decltype( data_.end() ) {
-      return data_.end();
-    }
-
-    const Shape& shape() const {
-      return shape_;
-    }
-
-    size_t size() const {
-      return data_.size();
-    }
-
-    value_type* data() {
-      return thrust::raw_pointer_cast(data_.data());
-    }
-
-    cudnnTensorDescriptor_t desc() const {
-      return desc_;
-    }
-
-    size_t id() const {
-      return tno_;
-    }
-
-    void set(value_type value) {
-      thrust::fill(data_.begin(), data_.end(), value);
-    }
-};
-
-template <typename Float>
-size_t TensorImpl<Float>::tensorCounter = 0;
-
-class Tensor {
-  private:
-    std::shared_ptr<TensorImpl<Float>> pimpl_;
-
-  public:
-    typedef TensorImpl<Float>::value_type value_type;
-
-    Tensor() {}
-    ~Tensor() {}
-
-    void allocate(Shape shape, value_type value = 0) {
-      pimpl_.reset(new TensorImpl<Float>(shape, value));
-    }
-
-    value_type operator[](size_t i) const {
-      return (*pimpl_)[i];
-    }
-
-    size_t size() const {
-      return pimpl_->size();
-    }
-
-    value_type* data() {
-      return pimpl_->data();
-    }
-
-    const value_type* data() const {
-      return pimpl_->data();
-    }
-
-    auto begin() -> decltype( pimpl_->begin() ) {
-      return pimpl_->begin();
-    }
-
-    auto begin() const -> decltype( pimpl_->begin() ) {
-      return pimpl_->begin();
-    }
-
-    auto end() -> decltype( pimpl_->begin() ) {
-      return pimpl_->begin();
-    }
-
-    auto end() const -> decltype( pimpl_->begin() ) {
-      return pimpl_->begin();
-    }
-
-    const Shape& shape() const {
-      return pimpl_->shape();
-    }
-
-    cudnnTensorDescriptor_t desc() const {
-      return pimpl_->desc();
-    }
-
-    void set(value_type value) {
-      pimpl_->set(value);
-    }
-
-    size_t id() const {
-      return pimpl_->id();
-    }
-
-    operator bool() {
-      return pimpl_ != nullptr;
-    }
-};
-
-Tensor uniform(Tensor t, Float a=-0.1, Float b=0.1) {
-  std::vector<Float> r(t.size());
-  for(int i = 0; i < r.size(); i++)
-    r[i] = (Float(rand() % 2000) - 1000.0)/10000.0;
-  thrust::copy(r.begin(), r.end(), t.begin());
-  return t;
-};
-
-using namespace thrust::placeholders;
-#define MAX_THREADS 512
-#define MAX_BLOCKS 65535
-
-template <class Functor>
-__global__ void gElement(Functor functor, Float* out,
-                         size_t rows, size_t cols) {
-  for(int bid = 0; bid < rows; bid += gridDim.x) {
-    int j = bid + blockIdx.x;
-    if(j < rows) {
-      Float* rowOut = out + j * cols;
-      for(int tid = 0; tid < cols; tid += blockDim.x) {
-        int i = tid + threadIdx.x;
-        if(i < cols)
-          rowOut[i] = functor(rowOut[i]);;
-      }
-    }
-  }
-}
-
-template <class Functor>
-__global__ void gElement(Functor functor,
-                         Float* out, const Float* in,
-                         size_t rows, size_t cols) {
-  for(int bid = 0; bid < rows; bid += gridDim.x) {
-    int j = bid + blockIdx.x;
-    if(j < rows) {
-      Float* rowOut = out + j * cols;
-      const Float* rowIn = in + j * cols;
-
-      for(int tid = 0; tid < cols; tid += blockDim.x) {
-        int i = tid + threadIdx.x;
-        if(i < cols)
-          rowOut[i] = functor(rowOut[i], rowIn[i]);;
-      }
-    }
-  }
-}
-
-template <class Functor>
-__global__ void gElement(Functor functor,
-                         Float* out, const Float* in1, const Float* in2,
-                         size_t rows, size_t cols) {
-  for(int bid = 0; bid < rows; bid += gridDim.x) {
-    int j = bid + blockIdx.x;
-    if(j < rows) {
-      Float* rowOut = out + j * cols;
-      const Float* rowIn1 = in1 + j * cols;
-      const Float* rowIn2 = in2 + j * cols;
-
-      for(int tid = 0; tid < cols; tid += blockDim.x) {
-        int i = tid + threadIdx.x;
-        if(i < cols)
-          rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i]);
-      }
-    }
-  }
-}
-
-template <class Functor>
-__global__ void gElement(Functor functor,
-                         Float* out, const Float* in1,
-                         const Float* in2, const Float* in3,
-                         size_t rows, size_t cols) {
-  for(int bid = 0; bid < rows; bid += gridDim.x) {
-    int j = bid + blockIdx.x;
-    if(j < rows) {
-      Float* rowOut = out + j * cols;
-      const Float* rowIn1 = in1 + j * cols;
-      const Float* rowIn2 = in2 + j * cols;
-      const Float* rowIn3 = in3 + j * cols;
-
-      for(int tid = 0; tid < cols; tid += blockDim.x) {
-        int i = tid + threadIdx.x;
-        if(i < cols)
-          rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i], rowIn3[i]);
-      }
-    }
-  }
-}
-
-// @TODO add broadcasting
-
-template <class Functor>
-void Element(Functor functor, Tensor Out) {
-  Float* d_out = Out.data();
-  int blocks  = std::min(MAX_BLOCKS,  (int)Out.shape()[0]);
-  int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
-  gElement<<<blocks, threads>>>(functor, d_out,
-                                Out.shape()[0], Out.shape()[1]);
-  cudaStreamSynchronize(0);
-}
-
-template <class Functor>
-void Element(Functor functor,
-             Tensor Out, const Tensor In) {
-  Float* d_out = Out.data();
-  const Float* d_in = In.data();
-
-  int blocks  = std::min(MAX_BLOCKS,  (int)Out.shape()[0]);
-  int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
-  gElement<<<blocks, threads>>>(functor, d_out, d_in,
-                                Out.shape()[0], Out.shape()[1]);
-  cudaStreamSynchronize(0);
-}
-
-template <class Functor>
-void Element(Functor functor,
-             Tensor Out, const Tensor In1, const Tensor In2) {
-
-  Float* d_out = Out.data();
-  const Float* d_in1 = In1.data();
-  const Float* d_in2 = In2.data();
-
-  int blocks  = std::min(MAX_BLOCKS,  (int)Out.shape()[0]);
-  int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
-  gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2,
-                                Out.shape()[0], Out.shape()[1]);
-  cudaStreamSynchronize(0);
-}
-
-template <class Functor>
-void Element(Functor functor,
-             Tensor Out, const Tensor In1,
-             const Tensor In2, const Tensor In3) {
-
-  Float* d_out = Out.data();
-  const Float* d_in1 = In1.data();
-  const Float* d_in2 = In2.data();
-  const Float* d_in3 = In3.data();
-
-  int blocks  = std::min(MAX_BLOCKS,  (int)Out.shape()[0]);
-  int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
-  gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2, d_in3,
-                                Out.shape()[0], Out.shape()[1]);
-  cudaStreamSynchronize(0);
-}
-
-Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
-            bool transA, bool transB, Float beta) {
-  Float alpha = 1.0;
-
-  size_t m = A.shape()[0];
-  size_t k = A.shape()[1];
-  if(transA)
-    std::swap(m, k);
-
-  size_t l = B.shape()[0];
-  size_t n = B.shape()[1];
-  if(transB)
-    std::swap(l, n);
-
-  size_t lda = A.shape()[1];
-  size_t ldb = B.shape()[1];
-  size_t ldc = B.shape()[1];
-
-  if(transB)
-    ldc = B.shape()[0];
-
-  cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
-  cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
-
-  cublasSgemm(handle, opB, opA,
-              n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc);
-  return C;
-}
-
-Tensor Prod(Tensor C, const Tensor A, const Tensor B,
-            bool transA, bool transB, Float beta = 0) {
-
-  return Prod(handles.cublasHandle, C, A, B, transA, transB, beta);
-}
-
-}
\ No newline at end of file
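
Aside (not part of the patch): the deleted Prod() feeds row-major tensors to column-major cuBLAS by swapping the operands in the cublasSgemm call. A row-major matrix reinterpreted as column-major is its transpose, so asking for B^T * A^T makes cuBLAS produce C^T in column-major order, which is exactly C in row-major order. A minimal host-side check of the identity this relies on, with hypothetical 2x3 and 3x2 operands:

    #include <cassert>
    #include <vector>

    typedef std::vector<float> Mat;  // row-major storage

    // c = a (m x k) times b (k x n), all row-major.
    Mat matmul(const Mat& a, const Mat& b, int m, int k, int n) {
      Mat c(m * n, 0.f);
      for (int i = 0; i < m; ++i)
        for (int p = 0; p < k; ++p)
          for (int j = 0; j < n; ++j)
            c[i*n + j] += a[i*k + p] * b[p*n + j];
      return c;
    }

    Mat transpose(const Mat& a, int rows, int cols) {
      Mat t(rows * cols);
      for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j)
          t[j*rows + i] = a[i*cols + j];
      return t;
    }

    int main() {
      int m = 2, k = 3, n = 2;
      Mat A = {1, 2, 3, 4, 5, 6};      // 2x3
      Mat B = {7, 8, 9, 10, 11, 12};   // 3x2
      Mat C  = matmul(A, B, m, k, n);                          // A*B
      Mat Ct = matmul(transpose(B, k, n), transpose(A, m, k),  // B^T*A^T
                      n, k, m);
      assert(Ct == transpose(C, m, n));  // (A*B)^T == B^T*A^T
      return 0;
    }
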
diff --git a/src/tensor.h b/src/tensor.h
index c9125b8c..2707a564 100644
--- a/src/tensor.h
+++ b/src/tensor.h
@@ -37,6 +37,17 @@ const Handles handles;
 
 typedef std::vector<int> Shape;
 
+inline std::string Debug(const Shape &shape)
+{
+  std::stringstream strm;
+  assert(shape.size());
+  strm << shape[0];
+  for (size_t i = 1; i < shape.size(); ++i) {
+    strm << "x" << shape[i];
+  }
+  return strm.str();
+}
+
 template <class Float>
 class TensorImpl {
   private:
@@ -145,10 +156,7 @@ class TensorImpl {
    {
      std::stringstream strm;
      assert(shape_.size());
-     strm << "shape=" << shape_[0];
-     for (size_t i = 1; i < shape_.size(); ++i) {
-       strm << "x" << shape_[i];
-     }
+     strm << "shape=" << marian::Debug(shape_);
      return strm.str();
    }
 };
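
Aside (not part of the patch): after this change TensorImpl::Debug() only adds the "shape=" prefix and reuses the free marian::Debug(Shape) helper, so a call like the one test.cu makes below reports:

    Tensor tx({500, 784}, 1);
    std::cerr << tx.Debug() << std::endl;   // -> "shape=500x784"
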
diff --git a/src/test.cu b/src/test.cu
index db3ec9d3..42933299 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -14,13 +14,17 @@ int main(int argc, char** argv) {
   Expr w = param(shape={784, 10}, name="W0");
   Expr b = param(shape={1, 10}, name="b0");
 
-  Expr lr = softmax(dot(x, w) + b, axis=1, name="pred");
+  Expr n5 = dot(x, w);
+  Expr n6 = n5 + b;
+  Expr lr = softmax(n6, axis=1, name="pred");
+  cerr << "lr=" << lr.Debug() << endl;
+
   Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
 
   Tensor tx({500, 784}, 1);
   Tensor ty({500, 10}, 1);
-  cerr << "tx=" << tx.Debug();
-  cerr << "ty=" << ty.Debug();
+  cerr << "tx=" << tx.Debug() << endl;
+  cerr << "ty=" << ty.Debug() << endl;
 
   x = tx;
   y = ty;
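
Aside (not part of the patch): assuming shapes propagate through dot and softmax so that lr's node already reports {500, 10} when Debug() is called, the added trace lines would print in the form:

    lr=500x10
    tx=shape=500x784
    ty=shape=500x10
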