From 7fd950fbda443066f5c5ca24db2de35254de9164 Mon Sep 17 00:00:00 2001
From: Marcin Junczys-Dowmunt
Date: Sat, 7 May 2016 22:56:23 +0200
Subject: [PATCH] working basic training

---
 .gitignore             |  36 ++++
 CMakeLists.txt         |  22 +++
 src/CMakeLists.txt     |  18 ++
 src/cudnn_tensor.h     | 400 +++++++++++++++++++++++++++++++++++++++++
 src/exception.cpp      | 108 +++++++++++
 src/exception.h        | 156 ++++++++++++++++
 src/marian.h           | 248 ++++++++-----------------
 src/operators.h        | 370 ++++++++++++++++++++++++++++++++++++++
 src/tensor.h           | 117 ++++++++++++
 src/test.cpp           |  55 ------
 src/test.cu            |  71 ++++++++
 src/thrust_functions.h |  95 ++++++++++
 12 files changed, 1467 insertions(+), 229 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 CMakeLists.txt
 create mode 100644 src/CMakeLists.txt
 create mode 100644 src/cudnn_tensor.h
 create mode 100644 src/exception.cpp
 create mode 100644 src/exception.h
 create mode 100644 src/operators.h
 create mode 100644 src/tensor.h
 delete mode 100644 src/test.cpp
 create mode 100644 src/test.cu
 create mode 100644 src/thrust_functions.h

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..266c3d72
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,36 @@
+# Compiled Object files
+*.slo
+*.lo
+*.o
+*.obj
+
+# Precompiled Headers
+*.gch
+*.pch
+
+# Compiled Dynamic libraries
+*.so
+*.dylib
+*.dll
+
+# Fortran module files
+*.mod
+
+# python compiled files
+*.pyc
+
+# Compiled Static libraries
+*.lai
+*.la
+*.a
+*.lib
+
+# Executables
+*.exe
+*.out
+*.app
+
+# Temporary files created by editors
+.*.sw*
+
+build
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 00000000..2e9b9041
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,22 @@
+cmake_minimum_required(VERSION 3.5.1)
+set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
+
+project(marian CXX)
+SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O0 -funroll-loops -Wno-unused-result -Wno-deprecated")
+LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O0; -arch=sm_35; -lineinfo; --use_fast_math;)
+add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
+SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
+
+include_directories(${amunn_SOURCE_DIR})
+find_package(CUDA REQUIRED)
+
+find_package(Boost COMPONENTS system timer)
+if(Boost_FOUND)
+  include_directories(${Boost_INCLUDE_DIRS})
+  set(EXT_LIBS ${EXT_LIBS} ${Boost_LIBRARIES})
+else(Boost_FOUND)
+  message(SEND_ERROR "Cannot find Boost libraries. Terminating." )
+endif(Boost_FOUND)
+
+include_directories(${marian_SOURCE_DIR}/src)
+add_subdirectory(src)
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
new file mode 100644
index 00000000..3d751b51
--- /dev/null
+++ b/src/CMakeLists.txt
@@ -0,0 +1,18 @@
+
+include_directories(.)
+ +add_library(libcommon OBJECT + exception.cpp +) + +cuda_add_executable( + marian + test.cu + $ +) + +foreach(exec marian) + target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn) + cuda_add_cublas_to_target(${exec}) + set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") +endforeach(exec) diff --git a/src/cudnn_tensor.h b/src/cudnn_tensor.h new file mode 100644 index 00000000..cd71d942 --- /dev/null +++ b/src/cudnn_tensor.h @@ -0,0 +1,400 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "exception.h" +#include "thrust_functions.h" + +namespace marian { + +struct Handles { + cudnnHandle_t cudnnHandle; + cublasHandle_t cublasHandle; + + cudnnOpTensorDescriptor_t add; + + Handles() { + cudnnCreate(&cudnnHandle); + cublasCreate(&cublasHandle); + cudnnCreateOpTensorDescriptor(&add); + cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN); + } + + ~Handles() { + cudnnDestroy(cudnnHandle); + cublasDestroy(cublasHandle); + cudnnDestroyOpTensorDescriptor(add); + } +}; + +Handles handles; + +typedef std::vector Shape; + +template +class TensorImpl { + private: + Shape shape_; + thrust::device_vector data_; + cudnnTensorDescriptor_t desc_; + size_t tno_; + static size_t tensorCounter; + + cudnnDataType_t dataType() { + switch(sizeof(Float)) { + case 2: return CUDNN_DATA_HALF; + case 8: return CUDNN_DATA_DOUBLE; + default: return CUDNN_DATA_FLOAT; + } + } + + public: + typedef Float value_type; + + TensorImpl(const Shape& shape, value_type value = 0) + : shape_(shape), tno_(tensorCounter++) + { + // @TODO: + UTIL_THROW_IF2(shape_.size() != 2, + "For now, only 2D Tensors, will be fixed later."); + + UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4, + "Wrong number of dimensions: " << shape_.size()); + int size = std::accumulate(shape_.begin(), shape_.end(), + 1, std::multiplies()); + data_.resize(size, value); + cudnnCreateTensorDescriptor(&desc_); + switch (shape_.size()) { + case 1: + cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + shape_[0], 1, 1, 1); break; + case 2: + cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + shape_[0], shape_[1], 1, 1); break; + case 3: + cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + shape_[0], shape_[1], shape_[2], 1); break; + case 4: + cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + shape_[0], shape_[1], shape_[2], shape_[3]); break; + } + } + + TensorImpl(const TensorImpl&) = delete; + TensorImpl(TensorImpl&&) = delete; + + ~TensorImpl() { + cudnnDestroyTensorDescriptor(desc_); + } + + value_type operator[](size_t i) const { + return data_[i]; + } + + auto begin() -> decltype( data_.begin() ) { + return data_.begin(); + } + + auto begin() const -> decltype( data_.begin() ) { + return data_.begin(); + } + + auto end() -> decltype( data_.end() ) { + return data_.end(); + } + + auto end() const -> decltype( data_.end() ) { + return data_.end(); + } + + const Shape& shape() const { + return shape_; + } + + size_t size() const { + return data_.size(); + } + + value_type* data() { + return thrust::raw_pointer_cast(data_.data()); + } + + cudnnTensorDescriptor_t desc() const { + return desc_; + } + + size_t id() const { + return tno_; + } + + void set(value_type value) { + thrust::fill(data_.begin(), data_.end(), value); + } +}; + +template +size_t TensorImpl::tensorCounter = 0; + +class Tensor { + private: + std::shared_ptr> pimpl_; + + public: + typedef 
TensorImpl::value_type value_type; + + Tensor(const Shape& shape, value_type value = 0) + : pimpl_(new TensorImpl(shape, value)) {} + + // Single value with broadcasting super powers. Might be + // worth getting rid of this performance-wise, but is saves + // so much typing when defining operators. + Tensor(value_type value) + : pimpl_(new TensorImpl({1, 1}, value)) {} + + Tensor() {} + + ~Tensor() {} + + value_type operator[](size_t i) const { + return (*pimpl_)[i]; + } + + size_t size() const { + return pimpl_->size(); + } + + value_type* data() { + return pimpl_->data(); + } + + const value_type* data() const { + return pimpl_->data(); + } + + auto begin() -> decltype( pimpl_->begin() ) { + return pimpl_->begin(); + } + + auto begin() const -> decltype( pimpl_->begin() ) { + return pimpl_->begin(); + } + + auto end() -> decltype( pimpl_->begin() ) { + return pimpl_->begin(); + } + + auto end() const -> decltype( pimpl_->begin() ) { + return pimpl_->begin(); + } + + const Shape& shape() const { + return pimpl_->shape(); + } + + cudnnTensorDescriptor_t desc() const { + return pimpl_->desc(); + } + + void set(value_type value) { + pimpl_->set(value); + } + + size_t id() const { + return pimpl_->id(); + } + + operator bool() { + return pimpl_ != nullptr; + } +}; + +Tensor uniform(Tensor t, float a=-0.1, float b=0.1) { + std::vector r(t.size()); + for(int i = 0; i < r.size(); i++) + r[i] = (float(rand() % 2000) - 1000.0)/10000.0; + thrust::copy(r.begin(), r.end(), t.begin()); + return t; +}; + +using namespace thrust::placeholders; +#define MAX_THREADS 512 +#define MAX_BLOCKS 65535 + +template +__global__ void gElement(Functor functor, float* out, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + float* rowOut = out + j * cols; + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i]);; + } + } + } +} + +template +__global__ void gElement(Functor functor, + float* out, const float* in, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + float* rowOut = out + j * cols; + const float* rowIn = in + j * cols; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i], rowIn[i]);; + } + } + } +} + +template +__global__ void gElement(Functor functor, + float* out, const float* in1, const float* in2, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + float* rowOut = out + j * cols; + const float* rowIn1 = in1 + j * cols; + const float* rowIn2 = in2 + j * cols; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i]); + } + } + } +} + +template +__global__ void gElement(Functor functor, + float* out, const float* in1, + const float* in2, const float* in3, + size_t rows, size_t cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + float* rowOut = out + j * cols; + const float* rowIn1 = in1 + j * cols; + const float* rowIn2 = in2 + j * cols; + const float* rowIn3 = in3 + j * cols; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int i = tid + threadIdx.x; + if(i < cols) + rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i], rowIn3[i]); + } + } + } +} + +// @TODO add broadcasting + 
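For orientation, a usage sketch of the Element() wrappers defined just below (illustrative only, not part of the patch): each overload launches one of the gElement kernels above with one block per row and one thread per column, and applies a thrust placeholder expression element-wise. All tensors must share the same 2-D shape, since broadcasting is still a TODO here, and the translation unit must be compiled with nvcc.

  Tensor A({4, 2}, 1.0);            // 4x2 tensor filled with ones
  Tensor B({4, 2}, 2.0);            // 4x2 tensor filled with twos
  Tensor C({4, 2});                 // result, initialised to zero
  Element(_1 = _2 + _3, C, A, B);   // element-wise C = A + B on the GPU
  Element(_1 = Tanh(_2), C, A);     // C = tanh(A), using Tanh from thrust_functions.h

The placeholders _1.._4 come from thrust::placeholders, which this header already pulls in.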
+template +void Element(Functor functor, Tensor Out) { + float* d_out = Out.data(); + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +template +void Element(Functor functor, + Tensor Out, const Tensor In) { + float* d_out = Out.data(); + const float* d_in = In.data(); + + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, d_in, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +template +void Element(Functor functor, + Tensor Out, const Tensor In1, const Tensor In2) { + + float* d_out = Out.data(); + const float* d_in1 = In1.data(); + const float* d_in2 = In2.data(); + + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, d_in1, d_in2, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +template +void Element(Functor functor, + Tensor Out, const Tensor In1, + const Tensor In2, const Tensor In3) { + + float* d_out = Out.data(); + const float* d_in1 = In1.data(); + const float* d_in2 = In2.data(); + const float* d_in3 = In3.data(); + + int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]); + int threads = std::min(MAX_THREADS, (int)Out.shape()[1]); + gElement<<>>(functor, d_out, d_in1, d_in2, d_in3, + Out.shape()[0], Out.shape()[1]); + cudaStreamSynchronize(0); +} + +Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B, + bool transA, bool transB, float beta) { + float alpha = 1.0; + + size_t m = A.shape()[0]; + size_t k = A.shape()[1]; + if(transA) + std::swap(m, k); + + size_t l = B.shape()[0]; + size_t n = B.shape()[1]; + if(transB) + std::swap(l, n); + + size_t lda = A.shape()[1]; + size_t ldb = B.shape()[1]; + size_t ldc = B.shape()[1]; + + if(transB) + ldc = B.shape()[0]; + + cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; + + cublasSgemm(handle, opB, opA, + n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc); + return C; +} + +Tensor Prod(Tensor C, const Tensor A, const Tensor B, + bool transA, bool transB, float beta = 0) { + + return Prod(handles.cublasHandle, C, A, B, transA, transB, beta); +} + +} \ No newline at end of file diff --git a/src/exception.cpp b/src/exception.cpp new file mode 100644 index 00000000..453fcf66 --- /dev/null +++ b/src/exception.cpp @@ -0,0 +1,108 @@ +#include "exception.h" + +#ifdef __GXX_RTTI +#include +#endif + +#include +#include + +#if defined(_WIN32) || defined(_WIN64) +#include +#include +#endif + +namespace util { + +Exception::Exception() throw() {} +Exception::~Exception() throw() {} + +Exception::Exception(const Exception& o) throw() { + what_.str(o.what_.str()); +} + +void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) { + /* The child class might have set some text, but we want this to come first. + * Another option would be passing this information to the constructor, but + * then child classes would have to accept constructor arguments and pass + * them down. 
+ */ + std::string old_text = what_.str(); + what_.str(std::string()); + what_ << file << ':' << line; + if (func) what_ << " in " << func << " threw "; + if (child_name) { + what_ << child_name; + } else { +#ifdef __GXX_RTTI + what_ << typeid(this).name(); +#else + what_ << "an exception"; +#endif + } + if (condition) { + what_ << " because `" << condition << '\''; + } + what_ << ".\n"; + what_ << old_text; +} + +namespace { + +#ifdef __GNUC__ +const char *HandleStrerror(int ret, const char *buf) __attribute__ ((unused)); +const char *HandleStrerror(const char *ret, const char * /*buf*/) __attribute__ ((unused)); +#endif +// At least one of these functions will not be called. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-function" +#endif +// The XOPEN version. +const char *HandleStrerror(int ret, const char *buf) { + if (!ret) return buf; + return NULL; +} + +// The GNU version. +const char *HandleStrerror(const char *ret, const char * /*buf*/) { + return ret; +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif +} // namespace + +ErrnoException::ErrnoException() throw() : errno_(errno) { + char buf[200]; + buf[0] = 0; +#if defined(sun) || defined(_WIN32) || defined(_WIN64) + const char *add = strerror(errno); +#else + const char *add = HandleStrerror(strerror_r(errno, buf, 200), buf); +#endif + + if (add) { + *this << add << ' '; + } +} + +ErrnoException::~ErrnoException() throw() {} + +OverflowException::OverflowException() throw() {} +OverflowException::~OverflowException() throw() {} + +#if defined(_WIN32) || defined(_WIN64) +WindowsException::WindowsException() throw() { + unsigned int last_error = GetLastError(); + char error_msg[256] = ""; + if (!FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, last_error, LANG_NEUTRAL, error_msg, sizeof(error_msg), NULL)) { + *this << "Windows error " << GetLastError() << " while formatting Windows error " << last_error << ". "; + } else { + *this << "Windows error " << last_error << ": " << error_msg; + } +} +WindowsException::~WindowsException() throw() {} +#endif + +} // namespace util diff --git a/src/exception.h b/src/exception.h new file mode 100644 index 00000000..85827d8c --- /dev/null +++ b/src/exception.h @@ -0,0 +1,156 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace util { + +template typename Except::template ExceptionTag::Identity operator<<(Except &e, const Data &data); + +class Exception : public std::exception { + public: + Exception() throw(); + virtual ~Exception() throw(); + Exception(const Exception& o) throw(); + + const char *what() const throw() { return what_.str().c_str(); } + + // For use by the UTIL_THROW macros. + void SetLocation( + const char *file, + unsigned int line, + const char *func, + const char *child_name, + const char *condition); + + private: + template friend typename Except::template ExceptionTag::Identity operator<<(Except &e, const Data &data); + + // This helps restrict operator<< defined below. + template struct ExceptionTag { + typedef T Identity; + }; + + std::stringstream what_; +}; + +/* This implements the normal operator<< for Exception and all its children. + * SFINAE means it only applies to Exception. Think of this as an ersatz + * boost::enable_if. 
+ */ +template typename Except::template ExceptionTag::Identity operator<<(Except &e, const Data &data) { + e.what_ << data; + return e; +} + +#ifdef __GNUC__ +#define UTIL_FUNC_NAME __PRETTY_FUNCTION__ +#else +#ifdef _WIN32 +#define UTIL_FUNC_NAME __FUNCTION__ +#else +#define UTIL_FUNC_NAME NULL +#endif +#endif + +/* Create an instance of Exception, add the message Modify, and throw it. + * Modify is appended to the what() message and can contain << for ostream + * operations. + * + * do .. while kludge to swallow trailing ; character + * http://gcc.gnu.org/onlinedocs/cpp/Swallowing-the-Semicolon.html . + * Arg can be a constructor argument to the exception. + */ +#define UTIL_THROW_BACKEND(Condition, Exception, Arg, Modify) do { \ + Exception UTIL_e Arg; \ + UTIL_e.SetLocation(__FILE__, __LINE__, UTIL_FUNC_NAME, #Exception, Condition); \ + UTIL_e << Modify; \ + throw UTIL_e; \ +} while (0) + +#define UTIL_THROW_ARG(Exception, Arg, Modify) \ + UTIL_THROW_BACKEND(NULL, Exception, Arg, Modify) + +#define UTIL_THROW(Exception, Modify) \ + UTIL_THROW_BACKEND(NULL, Exception, , Modify); + +#define UTIL_THROW2(Modify) \ + UTIL_THROW_BACKEND(NULL, util::Exception, , Modify); + +#if __GNUC__ >= 3 +#define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0) +#else +#define UTIL_UNLIKELY(x) (x) +#endif + +#if __GNUC__ >= 3 +#define UTIL_LIKELY(x) __builtin_expect (!!(x), 1) +#else +#define UTIL_LIKELY(x) (x) +#endif + +#define UTIL_THROW_IF_ARG(Condition, Exception, Arg, Modify) do { \ + if (UTIL_UNLIKELY(Condition)) { \ + UTIL_THROW_BACKEND(#Condition, Exception, Arg, Modify); \ + } \ +} while (0) + +#define UTIL_THROW_IF(Condition, Exception, Modify) \ + UTIL_THROW_IF_ARG(Condition, Exception, , Modify) + +#define UTIL_THROW_IF2(Condition, Modify) \ + UTIL_THROW_IF_ARG(Condition, util::Exception, , Modify) + +// Exception that records errno and adds it to the message. +class ErrnoException : public Exception { + public: + ErrnoException() throw(); + + virtual ~ErrnoException() throw(); + + int Error() const throw() { return errno_; } + + private: + int errno_; +}; + +// file wasn't there, or couldn't be open for some reason +class FileOpenException : public Exception { + public: + FileOpenException() throw() {} + ~FileOpenException() throw() {} +}; + +// Utilities for overflow checking. +class OverflowException : public Exception { + public: + OverflowException() throw(); + ~OverflowException() throw(); +}; + +template inline std::size_t CheckOverflowInternal(uint64_t value) { + UTIL_THROW_IF(value > static_cast(std::numeric_limits::max()), OverflowException, "Integer overflow detected. This model is too big for 32-bit code."); + return value; +} + +template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) { + return value; +} + +inline std::size_t CheckOverflow(uint64_t value) { + return CheckOverflowInternal(value); +} + +#if defined(_WIN32) || defined(_WIN64) +/* Thrown for Windows specific operations. */ +class WindowsException : public Exception { + public: + WindowsException() throw(); + ~WindowsException() throw(); +}; +#endif + +} // namespace util diff --git a/src/marian.h b/src/marian.h index 18d561c4..b8320d91 100644 --- a/src/marian.h +++ b/src/marian.h @@ -5,211 +5,111 @@ #include #include -#include +#include "exception.h" +#include "cudnn_tensor.h" namespace marian { -typedef float Tensor; // Now do this for cuDNN tensors! 
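This hunk swaps the scalar float autodiff (the removed '-' lines) for a tensor-valued, tape-based design: every Node pushes itself onto the global ChainableStack when it is constructed, Var::forward() evaluates the stack in construction order, and Var::backward() zeroes all adjoints, seeds the root adjoint with ones, and replays the stack in reverse. A tiny illustration of the intended semantics, using the operators added later in src/operators.h (illustrative only, not part of the patch):

  Var a = Tensor({1, 1}, 2.0);   // leaf node, pushed onto the stack
  Var b = Tensor({1, 1}, 3.0);   // leaf node
  Var y = a * b;                 // MultNodeOp, pushed last, hence the root
  y.forward();                   // runs forward() on every stacked node in order
  y.backward();                  // seeds dy/dy = 1, replays backward() in reverse
  // afterwards a.grad() holds dy/da = 3 and b.grad() holds dy/db = 2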
-struct Chainable; - -boost::pool<> p(sizeof(char)); -std::vector stack; - -struct Chainable { +template +struct Chainable : public std::enable_shared_from_this> { Chainable() { } virtual ~Chainable() { } - - virtual void chain() { } + virtual void forward() { } + virtual void backward() { } virtual void init_dependent() { } virtual void set_zero_adjoint() { } - - static inline void* operator new(size_t nbytes) { - // thread_local variable - return p.ordered_malloc(nbytes); - } + + virtual DataType val() = 0; + virtual DataType grad() = 0; }; -class Vimpl : public Chainable { +typedef std::vector*> ChainableStack; +typedef std::shared_ptr> ChainPtr; + +ChainableStack stack; + +class Node : public Chainable { public: - Vimpl(const Tensor& t) : val_{std::move(t)}, adj_{0} { + Node(const Tensor t) : val_(t) { + //std::cerr << "Putting node with tensor " << t.id() << " on stack" << std::endl; stack.push_back(this); } - ~Vimpl() {}; + virtual ~Node() {}; - virtual void init_dependent() { adj_ = 1; } - virtual void set_zero_adjoint() { adj_ = 0; } + virtual void init_dependent() { + if(adj_) { + adj_.set(1); + } + else { + adj_ = Tensor(val_.shape(), 1); + } + } - const Tensor& val() const { return val_; }; - Tensor& grad() { return adj_; }; + virtual void set_zero_adjoint() { + if(adj_) { + adj_.set(0); + } + else { + adj_ = Tensor(val_.shape(), 0); + } + } + + virtual Tensor val() { return val_; }; + virtual Tensor grad() { return adj_; }; protected: - const Tensor val_; - Tensor adj_; + Tensor val_; + Tensor adj_; }; -typedef Vimpl* VimplPtr; - -static void set_zero_all_adjoints() { - for(auto&& v : stack) - v->set_zero_adjoint(); -} - -static void grad(Chainable* v) { - typedef std::vector::reverse_iterator It; - v->init_dependent(); - for(It it = stack.rbegin(); it != stack.rend(); ++it) { - (*it)->chain(); - } -} - class Var { public: - Var() : vimpl_{nullptr} {} - Var(const Tensor& t) : vimpl_{new Vimpl{t}} {} - Var(const VimplPtr& vimpl) : vimpl_{vimpl} {} + Var() : pimpl_(nullptr) {} + Var(const Tensor t) : pimpl_(new Node(t)) {} + Var(const Tensor::value_type v) : pimpl_(new Node(Tensor(v))) {} + Var(const ChainPtr chainable) : pimpl_(chainable) {} + Var(Chainable* chainable) : pimpl_(chainable) {} - const Tensor& val() const { - return vimpl_->val(); + Tensor val() { + return pimpl_->val(); } - Tensor& grad() { - return vimpl_->grad(); + Tensor grad() { + return pimpl_->grad(); } - VimplPtr vimpl() const { - return vimpl_; + ChainPtr pimpl() { + return pimpl_; } - void calc_gradients() { - marian::grad(vimpl_); + void forward() { + UTIL_THROW_IF2(pimpl_.get() != stack.back(), + "Trying to call forward on non-root of computation graph"); + + for(auto&& v : stack) + v->forward(); + } + + void backward() { + UTIL_THROW_IF2(pimpl_.get() != stack.back(), + "Trying to call backward on non-root of computation graph"); + + for(auto&& v : stack) + v->set_zero_adjoint(); + + typedef ChainableStack::reverse_iterator It; + pimpl_->init_dependent(); + for(It it = stack.rbegin(); it != stack.rend(); ++it) + (*it)->backward(); + } + + operator ChainPtr() { + return pimpl_; } private: - VimplPtr vimpl_; + ChainPtr pimpl_; }; -/////////////////////////////////////////////////// - -struct OpVimpl : public Vimpl { - OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { } - - VimplPtr a_; -}; - - -struct LogVimpl : public OpVimpl { - LogVimpl(VimplPtr a) : OpVimpl(std::log(a->val()), a) { } - - void chain() { - a_->grad() += adj_ / a_->val(); - } -}; - -inline Var log(const Var& a) { - return 
Var(VimplPtr(new LogVimpl(a.vimpl()))); -} - -struct ExpVimpl : public OpVimpl { - ExpVimpl(VimplPtr a) : OpVimpl(std::exp(a->val()), a) { } - - void chain() { - a_->grad() += adj_ * std::exp(a_->val()); - } -}; - -inline Var exp(const Var& a) { - return Var(VimplPtr(new ExpVimpl(a.vimpl()))); -} - -struct NegVimpl : public OpVimpl { - NegVimpl(VimplPtr a) : OpVimpl(-a->val(), a) { } - - void chain() { - a_->grad() -= adj_; - } -}; - -inline Var operator-(const Var& a) { - return Var(VimplPtr(new NegVimpl(a.vimpl()))); -} - -// @TODO: take care of large exponents -struct SigmaVimpl : public OpVimpl { - SigmaVimpl(VimplPtr a) : OpVimpl(1.f / (1.f + std::exp(-a->val())), a) { } - - void chain() { - Tensor l = 1.f / (1.f + std::exp(-a_->val())); - a_->grad() += adj_ * l * (1 - l); - } -}; - -inline Var sigma(const Var& a) { - return Var(VimplPtr(new SigmaVimpl(a.vimpl()))); -} - -/////////////////////////////////////////////////// - - -struct OpVimplVV : public Vimpl { - VimplPtr a_; - VimplPtr b_; - - OpVimplVV(Tensor t, VimplPtr a, VimplPtr b) - : Vimpl(t), a_(a), b_(b) { } -}; - -struct PlusVimplVV : public OpVimplVV { - PlusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() + b->val(), a, b) { } - - void chain() { - a_->grad() += adj_; - b_->grad() += adj_; - } -}; - -inline Var operator+(const Var& a, const Var& b) { - return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl()))); -} - -struct MinusVimplVV : public OpVimplVV { - MinusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() - b->val(), a, b) { } - - void chain() { - a_->grad() -= adj_; - b_->grad() -= adj_; - } -}; - -inline Var operator-(const Var& a, const Var& b) { - return Var(VimplPtr(new MinusVimplVV(a.vimpl(), b.vimpl()))); -} - -struct MultVimplVV : public OpVimplVV { - MultVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() * b->val(), a, b) { } - - void chain() { - a_->grad() += adj_ * b_->val(); - b_->grad() += adj_ * a_->val(); - } -}; - -inline Var operator*(const Var& a, const Var& b) { - return Var(VimplPtr(new MultVimplVV(a.vimpl(), b.vimpl()))); -} - -struct DivVimplVV : public OpVimplVV { - DivVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() / b->val(), a, b) { } - - void chain() { - a_->grad() += adj_ / b_->val(); - b_->grad() += adj_ * (a_->val() / (b_->val() * b_->val())); - } -}; - -inline Var operator/(const Var& a, const Var& b) { - return Var(VimplPtr(new DivVimplVV(a.vimpl(), b.vimpl()))); -} - - } \ No newline at end of file diff --git a/src/operators.h b/src/operators.h new file mode 100644 index 00000000..340e5188 --- /dev/null +++ b/src/operators.h @@ -0,0 +1,370 @@ +#pragma once + +#include +#include +#include +#include + +#include "marian.h" +#include "cudnn_tensor.h" + +namespace marian { + +/*** Unary operators ***/ + +struct UnaryNodeOp : public Node { + ChainPtr a_; + + UnaryNodeOp(const Tensor t, ChainPtr a) + : Node(t), a_(a) {} +}; + +struct SigmaNodeOp : public UnaryNodeOp { + SigmaNodeOp(ChainPtr a) + : UnaryNodeOp(Tensor(a->val().shape()), a) { } + + void forward() { + Element(_1 = Sigma(_2), + val_, a_->val()); + } + + void backward() { + Element(_1 += _2 * Sigma(_3) * (1 - Sigma(_3)), + a_->grad(), adj_, a_->val()); + } +}; + +inline Var sigma(Var a) { + return Var(new SigmaNodeOp(a)); +} + +struct TanhNodeOp : public UnaryNodeOp { + TanhNodeOp(ChainPtr a) + : UnaryNodeOp(Tensor(a->val().shape()), a) { } + + void forward() { + Element(_1 = Tanh(_2), + val_, a_->val()); + } + + void backward() { + Element(_1 += _2 * (1 - Tanh(_3) * Tanh(_3)), + a_->grad(), adj_, a_->val()); + } 
+}; + +inline Var tanh(Var a) { + return Var(new TanhNodeOp(a)); +} + +struct LogNodeOp : public UnaryNodeOp { + LogNodeOp(ChainPtr a) + : UnaryNodeOp(Tensor(a->val().shape()), a) { } + + void forward() { + Element(_1 = Log(_2), val_, a_->val()); + } + + void backward() { + Element(_1 += _2 * 1.f / _3, + a_->grad(), adj_, a_->val()); + } +}; + +inline Var log(Var a) { + return Var(new LogNodeOp(a)); +}; + +struct ExpNodeOp : public UnaryNodeOp { + ExpNodeOp(ChainPtr a) + : UnaryNodeOp(Tensor(a->val().shape()), a) { } + + void forward() { + Element(_1 = Exp(_2), val_, a_->val()); + } + + void backward() { + Element(_1 += _2 * Exp(_3), + a_->grad(), adj_, a_->val()); + } +}; + +inline Var exp(Var a) { + return Var(new ExpNodeOp(a)); +}; + +struct NegNodeOp : public UnaryNodeOp { + NegNodeOp(ChainPtr a) + : UnaryNodeOp(Tensor(a->val().shape()), a) { } + + void forward() { + Element(_1 = -_2, val_, a_->val()); + } + + void backward() { + Element(_1 += -_2, a_->grad(), adj_); + } +}; + +inline Var operator-(Var a) { + return Var(new NegNodeOp(a)); +}; + +/******************************************************/ + +struct BinaryNodeOp : public Node { + ChainPtr a_; + ChainPtr b_; + + BinaryNodeOp(const Tensor t, ChainPtr a, ChainPtr b) + : Node(t), a_(a), b_(b) {} +}; + +/*** Matrix Product ***/ + +struct DotNodeOp : public BinaryNodeOp { + DotNodeOp(ChainPtr a, ChainPtr b) : BinaryNodeOp(Tensor(shape(a, b)), a, b) { } + + Shape shape(ChainPtr a, ChainPtr b) { + UTIL_THROW_IF2(a->val().shape()[1] != b->val().shape()[0], + "matrix product requires dimensions to match"); + Shape shape1 = a->val().shape(); + Shape shape2 = b->val().shape(); + shape1[1] = shape2[1]; + return shape1; + } + + void forward() { + // C = A*B + Prod(val_, a_->val(), b_->val(), false, false); + } + + void backward() { + // D is the adjoint, the matrix of derivatives + // df/dA += D*B.T + // df/dB += A.T*D + // beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C + // to sum gradients from different graph parts + Prod(a_->grad(), adj_, b_->val(), false, true, 1.0); + Prod(b_->grad(), a_->val(), adj_, true, false, 1.0); + } +}; + +inline Var dot(Var a, Var b) { + return Var(new DotNodeOp(a, b)); +} + +/******************************************************/ + +Var broadcast(Shape shape, Var a) { + if(a.val().shape() == shape) { + return a; + } + else { + size_t dimsA = a.val().shape().size(); + size_t dimsB = shape.size(); + UTIL_THROW_IF2(dimsA != dimsB, + "Tensor and shape have different number of dimensions"); + for(size_t i = 0; i < dimsA; ++i) { + int dimA = a.val().shape()[i]; + int dimB = shape[i]; + bool broadcastable = (dimA == dimB || dimA == 1); + UTIL_THROW_IF2(!broadcastable, + "Cannot broadcast tensor dimension " + << dimA << " to " << dimB); + if(dimA == 1 && dimB > 1) { + std::cerr << "Broadcasting dim " << i << " from " << dimA << " to " << dimB << std::endl; + if(i == 0) { + Var one = Tensor({shape[0], 1}, 1); + a = dot(one, a); + } + else if(i == 1) { + Var one = Tensor({1, shape[1]}, 1); + a = dot(a, one); + } + else { + UTIL_THROW2("Not inplemented"); + } + } + } + return a; + } +} + +struct BroadcastingNodeOp : public BinaryNodeOp { + BroadcastingNodeOp(Var a, Var b) + : BroadcastingNodeOp(Tensor(shape(a ,b)), broadcast(shape(a ,b), a), broadcast(shape(a ,b), b)) {} + + static Shape shape(ChainPtr a, ChainPtr b) { + size_t dimsA = a->val().shape().size(); + size_t dimsB = b->val().shape().size(); + UTIL_THROW_IF2(dimsA != dimsB, + "Tensors have different numbers of dimensions"); + Shape shape(dimsA); 
+ for(size_t i = 0; i < dimsA; ++i) { + int dimA = a->val().shape()[i]; + int dimB = b->val().shape()[i]; + bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1); + UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise " + << "operation cannot be broadcasted: " << dimA << " != " << dimB); + shape[i] = std::max(dimA, dimB); + } + return shape; + } + + private: + BroadcastingNodeOp(const Tensor t, ChainPtr a, ChainPtr b) + : BinaryNodeOp(t, a, b) {} +}; + +/*** Binary arithmetic ***/ + +/*** Plus ***/ + +struct PlusNodeOp : public BroadcastingNodeOp { + PlusNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { } + + void forward() { + Element(_1 = _2 + _3, + val_, a_->val(), b_->val()); + } + + void backward() { + Element(_1 += _2, + a_->grad(), adj_); + Element(_1 += _2, + b_->grad(), adj_); + } +}; + +inline Var operator+(Var a, Var b) { + return Var(new PlusNodeOp(a, b)); +} + +/*** Minus ***/ + +struct MinusNodeOp : public BroadcastingNodeOp { + MinusNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { } + + void forward() { + Element(_1 = _2 - _3, + val_, a_->val(), b_->val()); + } + + void backward() { + Element(_1 += _2, + a_->grad(), adj_); + Element(_1 -= _2, + b_->grad(), adj_); + } +}; + +inline Var operator-(Var a, Var b) { + return Var(new MinusNodeOp(a, b)); +} + +/*** Mult ***/ + +struct MultNodeOp : public BroadcastingNodeOp { + MultNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { } + + void forward() { + Element(_1 = _2 * _3, + val_, a_->val(), b_->val()); + } + + void backward() { + Element(_1 += _2 * _3, + a_->grad(), adj_, b_->val()); + Element(_1 += _2 * _3, + b_->grad(), adj_, a_->val()); + } +}; + +inline Var operator*(Var a, Var b) { + return Var(new MultNodeOp(a, b)); +} + +/*** Division ***/ + +struct DivNodeOp : public BroadcastingNodeOp { + DivNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { } + + void forward() { + Element(_1 = _2 / _3, + val_, a_->val(), b_->val()); + } + + void backward() { + Element(_1 += _2 * 1.0f / _3, + a_->grad(), adj_, b_->val()); + Element(_1 -= _2 * _3 / (_4 * _4), + b_->grad(), adj_, a_->val(), b_->val()); + } +}; + +inline Var operator/(Var a, Var b) { + return Var(new DivNodeOp(a, b)); +} + + +/*** Reductions ***/ + +enum Axis { undef, axis0, axis1, axis2, axis3 }; + +// inefficient +inline Var sum(Var a, Axis axis = Axis::undef) { + if(axis == Axis::axis0) { + int rows = a.val().shape()[0]; + int cols = a.val().shape()[1]; + Var one = Tensor({1, rows}, 1); + return dot(one, a); + } + else if(axis == Axis::axis1) { + int rows = a.val().shape()[0]; + int cols = a.val().shape()[1]; + Var one = Tensor({cols, 1}, 1); + return dot(a, one); + } + else if(axis == Axis::axis2) { + UTIL_THROW2("Not inplemented"); + } + else if(axis == Axis::axis3) { + UTIL_THROW2("Not inplemented"); + } + return sum(sum(a, Axis::axis0), Axis::axis1); +} + +// inefficient +inline Var softmax(Var a, Axis axis = Axis::undef) { + Var e = exp(a); + return e / sum(e, axis); +} + +// inefficient +inline Var mean(Var a, Axis axis = Axis::undef) { + switch (axis) { + case Axis::axis0: + return sum(a, axis) / a.val().shape()[0]; + case Axis::axis1: + return sum(a, axis) / a.val().shape()[1]; + case Axis::axis2: + UTIL_THROW2("Not implemented"); + case Axis::axis3: + UTIL_THROW2("Not implemented"); + case Axis::undef: + default: + return sum(a) / a.val().size(); + } +} + +// FAKE +inline Var input(const std::string& name, Var v) { + return v; +} + +inline Var forsave(const std::string& name, Var v) { + return v; +} + +} \ No newline at end of file diff 
--git a/src/tensor.h b/src/tensor.h new file mode 100644 index 00000000..932278a1 --- /dev/null +++ b/src/tensor.h @@ -0,0 +1,117 @@ +#pragma once + +#include +#include +#include +#include + +namespace marian { + +class TensorImpl { + public: + typedef float value_type; + + TensorImpl(size_t size, value_type value) + : data_(size, value), tno_(tensorCounter++) + { + std::cerr << "Allocating tensor " << tno_ << std::endl; + } + + TensorImpl(const TensorImpl& t) + : data_(t.data_.begin(), t.data_.end()) + { + std::cerr << "Copying tensor " << tno_ << std::endl; + } + + ~TensorImpl() { + std::cerr << "Destroying tensor " << tno_ << std::endl; + } + + size_t size() const { + return data_.size(); + } + + value_type* data() { + return data_.data(); + } + + const value_type* data() const { + return data_.data(); + } + + size_t id() const { + return tno_; + } + + void set(value_type value) { + std::fill(data_.begin(), data_.end(), value); + } + + private: + std::vector data_; + size_t tno_; + + static size_t tensorCounter; +}; + +size_t TensorImpl::tensorCounter = 0; + +class Tensor { + public: + typedef TensorImpl::value_type value_type; + + Tensor(size_t size, float value) + : pimpl_(new TensorImpl(size, value)) {} + + Tensor() {} + + ~Tensor() {} + + size_t size() const { + return pimpl_->size(); + } + + float* data() { + return pimpl_->data(); + } + + const float* data() const { + return pimpl_->data(); + } + + void set(float value) { + pimpl_->set(value); + } + + size_t id() const { + return pimpl_->id(); + } + + private: + std::shared_ptr pimpl_; +}; + +Tensor operator+(const Tensor a, const Tensor b) { + Tensor c(a.size(), 0); + for(size_t i = 0; i < a.size(); ++i) { + c.data()[i] = a.data()[i] + b.data()[i]; + } + return c; +} + +Tensor operator*(const Tensor a, const Tensor b) { + Tensor c(a.size(), 0); + for(size_t i = 0; i < a.size(); ++i) { + c.data()[i] = a.data()[i] * b.data()[i]; + } + return c; +} + +Tensor operator+=(Tensor a, const Tensor b) { + for(size_t i = 0; i < a.size(); ++i) { + a.data()[i] += b.data()[i]; + } + return a; +} + +} \ No newline at end of file diff --git a/src/test.cpp b/src/test.cpp deleted file mode 100644 index 8d1e380c..00000000 --- a/src/test.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include -#include - -#include -#include - -#include -#include - -#include "marian.h" - -using namespace marian; - -Var layer(size_t max, std::vector& x) { - Var x0 = rand() % 100, x1 = rand() % 100, x2 = rand() % 100; - x = { x0, x1, x2 }; - - Var y = 0.0; - for(int i = 0; i < max; i++) { - Var xi = i; - x.push_back(xi); - y = y + x0 + log(x2) + x1; - for(int j = 0; j < i; ++j) { - y = y + xi; - } - } - - return y; -} - -int main(int argc, char** argv) { - srand(time(NULL)); - - std::vector x1, x2; - Var y1 = layer(10, x1); - Var y2 = layer(rand() % 20 + 1, x2); - - Var y = sigma(log(y1) / log(y2)); - - set_zero_all_adjoints(); - y.calc_gradients(); - - std::cerr << "y1 = " << y1.val() << std::endl; - std::cerr << "y2 = " << y2.val() << std::endl; - std::cerr << "y = " << y.val() << std::endl; - - std::cerr << "dy/dy1 = " << y1.grad() << std::endl; - std::cerr << "dy/dy2 = " << y2.grad() << std::endl; - - for(size_t i = 0; i < x1.size(); ++i) - std::cerr << "x1_" << i << " = " << x1[i].val() << " : dy/dx1_" << i << " = " << x1[i].grad() << std::endl; - for(size_t i = 0; i < x2.size(); ++i) - std::cerr << "x2_" << i << " = " << x2[i].val() << " : dy/dx2_" << i << " = " << x2[i].grad() << std::endl; - -} \ No newline at end of file diff --git a/src/test.cu b/src/test.cu new file 
mode 100644 index 00000000..db57cdc4 --- /dev/null +++ b/src/test.cu @@ -0,0 +1,71 @@ +#include +#include +#include +#include +#include +#include + +#include "marian.h" +#include "operators.h" + +using namespace marian; + +int main(int argc, char** argv) { + boost::timer::auto_cpu_timer t; + + Var x = input("X", Tensor({4, 2})); + Var y = input("Y", Tensor({4, 2})); + + std::vector vx = { + 0, 0, + 0, 1, + 1, 0, + 1, 1 + }; + + std::vector vy = { + 1, 0, + 1, 0, + 0, 1, + 1, 0 + }; + + thrust::copy(vx.begin(), vx.end(), x.val().begin()); + thrust::copy(vy.begin(), vy.end(), y.val().begin()); + + Var w0 = forsave("W0", uniform(Tensor({2, 2}))); + Var b0 = forsave("b0", uniform(Tensor({1, 2}))); + + Var w1 = forsave("W1", uniform(Tensor({2, 2}))); + Var b1 = forsave("b1", uniform(Tensor({1, 2}))); + + std::vector params = { w0, w1, b0, b1 }; + + Var ry = sigma(dot(x, w0) + b0); + ry = softmax(dot(ry, w1) + b1, Axis::axis1); + Var cost = -mean(sum(y * log(ry), Axis::axis1), Axis::axis0); + + float alpha = 0.1; + for(size_t i = 0; i < 30000; ++i) { + cost.forward(); + + if(i % 100 == 0) { + for(size_t j = 0; j < 4; ++j) { + std::cerr << ry.val()[j*2] << std::endl; + } + std::cerr << i << " ct: " << cost.val()[0] << std::endl; + // alpha = alpha * 0.9; + } + + cost.backward(); + for(auto p : params) { + //std::cerr << p.grad()[0] << std::endl; + auto update = + _1 -= alpha * _2; + + Element(update, p.val(), p.grad()); + } + } + + return 0; +} \ No newline at end of file diff --git a/src/thrust_functions.h b/src/thrust_functions.h new file mode 100644 index 00000000..a3013423 --- /dev/null +++ b/src/thrust_functions.h @@ -0,0 +1,95 @@ +#pragma once + +#include +#include +#include +#include + +namespace thrust +{ + namespace detail + { + namespace functional + { + + // Ugly hacks, but it seems this is neccessary. + __host__ __device__ + float expf2(float x) { + float clip = 16; + if(x > clip) + x = clip; + if(x < -clip) + x = -clip; + return expf(x); + } + + __host__ __device__ + float logf2(float x) { + if(x < 10e-10) + x = 10e-10; + return logf(x); + } + + template + struct unary_exp : public thrust::unary_function { + __host__ __device__ + T operator()(const T &x) const { return expf2(x); } + }; + + template + __host__ __device__ + actor, actor>> + Exp(const actor &_1) { + return compose(unary_operator(), _1); + } + + template + struct unary_log : public thrust::unary_function { + __host__ __device__ + T operator()(const T &x) const { return logf2(x); } + }; + + template + __host__ __device__ + actor, actor>> + Log(const actor &_1) { + return compose(unary_operator(), _1); + } + + template + struct unary_sigma : public thrust::unary_function { + __host__ __device__ + T operator()(const T &x) const { return 1.0 / (1.0 + expf2(-x)); } + }; + + template + __host__ __device__ + actor, actor>> + Sigma(const actor &_1) { + return compose(unary_operator(), _1); + } + + template + struct unary_tanh : public thrust::unary_function { + __host__ __device__ + T operator()(const T &x) const { return tanhf(x); } + }; + + template + __host__ __device__ + actor, actor>> + Tanh(const actor &_1) { + return compose(unary_operator(), _1); + } + + template + __host__ __device__ + actor, actor, actor>> + Max(const actor &_1, const actor &_2) { + return compose(binary_operator(), + make_actor(_1), + make_actor(_2)); + } + } + } +} \ No newline at end of file
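A closing note on Prod() in src/cudnn_tensor.h (explanatory only, not part of the patch): cuBLAS expects column-major storage, while these tensors are stored row-major. A row-major m-by-k matrix reinterpreted as column-major is its transpose, so C = A*B is obtained by asking cuBLAS for C^T = B^T * A^T. That is why cublasSgemm receives the operands in swapped order, cublasSgemm(handle, opB, opA, n, m, k, ..., B, ldb, A, lda, ..., C, ldc), with lda = A.shape()[1], ldb = B.shape()[1], and ldc equal to the number of columns of C. The buffer written by cuBLAS then already holds C in row-major layout, which is what DotNodeOp in src/operators.h relies on for both the forward product and the transposed gradient products. A minimal host-side sketch:

  Tensor A({2, 3}, 1.0);         // 2x3, all ones
  Tensor B({3, 2}, 2.0);         // 3x2, all twos
  Tensor C({2, 2});              // 2x2 result
  Prod(C, A, B, false, false);   // C = A*B; each entry sums 1*2 over k=3, i.e. 6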