working basic training

This commit is contained in:
Marcin Junczys-Dowmunt 2016-05-07 22:56:23 +02:00
parent 4291c918ae
commit 7fd950fbda
12 changed files with 1467 additions and 229 deletions

36
.gitignore vendored Normal file
View File

@ -0,0 +1,36 @@
# Compiled Object files
*.slo
*.lo
*.o
*.obj
# Precompiled Headers
*.gch
*.pch
# Compiled Dynamic libraries
*.so
*.dylib
*.dll
# Fortran module files
*.mod
# python compiled files
*.pyc
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Executables
*.exe
*.out
*.app
# Temporaty files created by editors
.*.sw*
build

22
CMakeLists.txt Normal file
View File

@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.5.1)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
project(marian CXX)
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O0 -funroll-loops -Wno-unused-result -Wno-deprecated")
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O0; -arch=sm_35; -lineinfo; --use_fast_math;)
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
include_directories(${amunn_SOURCE_DIR})
find_package(CUDA REQUIRED)
find_package(Boost COMPONENTS system timer)
if(Boost_FOUND)
include_directories(${Boost_INCLUDE_DIRS})
set(EXT_LIBS ${EXT_LIBS} ${Boost_LIBRARIES})
else(Boost_FOUND)
message(SEND_ERROR "Cannot find Boost libraries. Terminating." )
endif(Boost_FOUND)
include_directories(${marian_SOURCE_DIR}/src)
add_subdirectory(src)

18
src/CMakeLists.txt Normal file
View File

@ -0,0 +1,18 @@
include_directories(.)
add_library(libcommon OBJECT
exception.cpp
)
cuda_add_executable(
marian
test.cu
$<TARGET_OBJECTS:libcommon>
)
foreach(exec marian)
target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn)
cuda_add_cublas_to_target(${exec})
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
endforeach(exec)

400
src/cudnn_tensor.h Normal file
View File

@ -0,0 +1,400 @@
#pragma once
#include <memory>
#include <functional>
#include <vector>
#include <cmath>
#include <cudnn.h>
#include <cublas_v2.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include "exception.h"
#include "thrust_functions.h"
namespace marian {
struct Handles {
cudnnHandle_t cudnnHandle;
cublasHandle_t cublasHandle;
cudnnOpTensorDescriptor_t add;
Handles() {
cudnnCreate(&cudnnHandle);
cublasCreate(&cublasHandle);
cudnnCreateOpTensorDescriptor(&add);
cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN);
}
~Handles() {
cudnnDestroy(cudnnHandle);
cublasDestroy(cublasHandle);
cudnnDestroyOpTensorDescriptor(add);
}
};
Handles handles;
typedef std::vector<int> Shape;
template<class Float>
class TensorImpl {
private:
Shape shape_;
thrust::device_vector<Float> data_;
cudnnTensorDescriptor_t desc_;
size_t tno_;
static size_t tensorCounter;
cudnnDataType_t dataType() {
switch(sizeof(Float)) {
case 2: return CUDNN_DATA_HALF;
case 8: return CUDNN_DATA_DOUBLE;
default: return CUDNN_DATA_FLOAT;
}
}
public:
typedef Float value_type;
TensorImpl(const Shape& shape, value_type value = 0)
: shape_(shape), tno_(tensorCounter++)
{
// @TODO:
UTIL_THROW_IF2(shape_.size() != 2,
"For now, only 2D Tensors, will be fixed later.");
UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4,
"Wrong number of dimensions: " << shape_.size());
int size = std::accumulate(shape_.begin(), shape_.end(),
1, std::multiplies<int>());
data_.resize(size, value);
cudnnCreateTensorDescriptor(&desc_);
switch (shape_.size()) {
case 1:
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
shape_[0], 1, 1, 1); break;
case 2:
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
shape_[0], shape_[1], 1, 1); break;
case 3:
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
shape_[0], shape_[1], shape_[2], 1); break;
case 4:
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
shape_[0], shape_[1], shape_[2], shape_[3]); break;
}
}
TensorImpl(const TensorImpl&) = delete;
TensorImpl(TensorImpl&&) = delete;
~TensorImpl() {
cudnnDestroyTensorDescriptor(desc_);
}
value_type operator[](size_t i) const {
return data_[i];
}
auto begin() -> decltype( data_.begin() ) {
return data_.begin();
}
auto begin() const -> decltype( data_.begin() ) {
return data_.begin();
}
auto end() -> decltype( data_.end() ) {
return data_.end();
}
auto end() const -> decltype( data_.end() ) {
return data_.end();
}
const Shape& shape() const {
return shape_;
}
size_t size() const {
return data_.size();
}
value_type* data() {
return thrust::raw_pointer_cast(data_.data());
}
cudnnTensorDescriptor_t desc() const {
return desc_;
}
size_t id() const {
return tno_;
}
void set(value_type value) {
thrust::fill(data_.begin(), data_.end(), value);
}
};
template <typename Type>
size_t TensorImpl<Type>::tensorCounter = 0;
class Tensor {
private:
std::shared_ptr<TensorImpl<float>> pimpl_;
public:
typedef TensorImpl<float>::value_type value_type;
Tensor(const Shape& shape, value_type value = 0)
: pimpl_(new TensorImpl<value_type>(shape, value)) {}
// Single value with broadcasting super powers. Might be
// worth getting rid of this performance-wise, but is saves
// so much typing when defining operators.
Tensor(value_type value)
: pimpl_(new TensorImpl<value_type>({1, 1}, value)) {}
Tensor() {}
~Tensor() {}
value_type operator[](size_t i) const {
return (*pimpl_)[i];
}
size_t size() const {
return pimpl_->size();
}
value_type* data() {
return pimpl_->data();
}
const value_type* data() const {
return pimpl_->data();
}
auto begin() -> decltype( pimpl_->begin() ) {
return pimpl_->begin();
}
auto begin() const -> decltype( pimpl_->begin() ) {
return pimpl_->begin();
}
auto end() -> decltype( pimpl_->begin() ) {
return pimpl_->begin();
}
auto end() const -> decltype( pimpl_->begin() ) {
return pimpl_->begin();
}
const Shape& shape() const {
return pimpl_->shape();
}
cudnnTensorDescriptor_t desc() const {
return pimpl_->desc();
}
void set(value_type value) {
pimpl_->set(value);
}
size_t id() const {
return pimpl_->id();
}
operator bool() {
return pimpl_ != nullptr;
}
};
Tensor uniform(Tensor t, float a=-0.1, float b=0.1) {
std::vector<float> r(t.size());
for(int i = 0; i < r.size(); i++)
r[i] = (float(rand() % 2000) - 1000.0)/10000.0;
thrust::copy(r.begin(), r.end(), t.begin());
return t;
};
using namespace thrust::placeholders;
#define MAX_THREADS 512
#define MAX_BLOCKS 65535
template <class Functor>
__global__ void gElement(Functor functor, float* out,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
float* rowOut = out + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols)
rowOut[i] = functor(rowOut[i]);;
}
}
}
}
template <class Functor>
__global__ void gElement(Functor functor,
float* out, const float* in,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
float* rowOut = out + j * cols;
const float* rowIn = in + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols)
rowOut[i] = functor(rowOut[i], rowIn[i]);;
}
}
}
}
template <class Functor>
__global__ void gElement(Functor functor,
float* out, const float* in1, const float* in2,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
float* rowOut = out + j * cols;
const float* rowIn1 = in1 + j * cols;
const float* rowIn2 = in2 + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols)
rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i]);
}
}
}
}
template <class Functor>
__global__ void gElement(Functor functor,
float* out, const float* in1,
const float* in2, const float* in3,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
float* rowOut = out + j * cols;
const float* rowIn1 = in1 + j * cols;
const float* rowIn2 = in2 + j * cols;
const float* rowIn3 = in3 + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols)
rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i], rowIn3[i]);
}
}
}
}
// @TODO add broadcasting
template <class Functor>
void Element(Functor functor, Tensor Out) {
float* d_out = Out.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gElement<<<blocks, threads>>>(functor, d_out,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
template <class Functor>
void Element(Functor functor,
Tensor Out, const Tensor In) {
float* d_out = Out.data();
const float* d_in = In.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gElement<<<blocks, threads>>>(functor, d_out, d_in,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
template <class Functor>
void Element(Functor functor,
Tensor Out, const Tensor In1, const Tensor In2) {
float* d_out = Out.data();
const float* d_in1 = In1.data();
const float* d_in2 = In2.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
template <class Functor>
void Element(Functor functor,
Tensor Out, const Tensor In1,
const Tensor In2, const Tensor In3) {
float* d_out = Out.data();
const float* d_in1 = In1.data();
const float* d_in2 = In2.data();
const float* d_in3 = In3.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2, d_in3,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, float beta) {
float alpha = 1.0;
size_t m = A.shape()[0];
size_t k = A.shape()[1];
if(transA)
std::swap(m, k);
size_t l = B.shape()[0];
size_t n = B.shape()[1];
if(transB)
std::swap(l, n);
size_t lda = A.shape()[1];
size_t ldb = B.shape()[1];
size_t ldc = B.shape()[1];
if(transB)
ldc = B.shape()[0];
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasSgemm(handle, opB, opA,
n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc);
return C;
}
Tensor Prod(Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, float beta = 0) {
return Prod(handles.cublasHandle, C, A, B, transA, transB, beta);
}
}

108
src/exception.cpp Normal file
View File

@ -0,0 +1,108 @@
#include "exception.h"
#ifdef __GXX_RTTI
#include <typeinfo>
#endif
#include <cerrno>
#include <cstring>
#if defined(_WIN32) || defined(_WIN64)
#include <windows.h>
#include <io.h>
#endif
namespace util {
Exception::Exception() throw() {}
Exception::~Exception() throw() {}
Exception::Exception(const Exception& o) throw() {
what_.str(o.what_.str());
}
void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) {
/* The child class might have set some text, but we want this to come first.
* Another option would be passing this information to the constructor, but
* then child classes would have to accept constructor arguments and pass
* them down.
*/
std::string old_text = what_.str();
what_.str(std::string());
what_ << file << ':' << line;
if (func) what_ << " in " << func << " threw ";
if (child_name) {
what_ << child_name;
} else {
#ifdef __GXX_RTTI
what_ << typeid(this).name();
#else
what_ << "an exception";
#endif
}
if (condition) {
what_ << " because `" << condition << '\'';
}
what_ << ".\n";
what_ << old_text;
}
namespace {
#ifdef __GNUC__
const char *HandleStrerror(int ret, const char *buf) __attribute__ ((unused));
const char *HandleStrerror(const char *ret, const char * /*buf*/) __attribute__ ((unused));
#endif
// At least one of these functions will not be called.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-function"
#endif
// The XOPEN version.
const char *HandleStrerror(int ret, const char *buf) {
if (!ret) return buf;
return NULL;
}
// The GNU version.
const char *HandleStrerror(const char *ret, const char * /*buf*/) {
return ret;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
} // namespace
ErrnoException::ErrnoException() throw() : errno_(errno) {
char buf[200];
buf[0] = 0;
#if defined(sun) || defined(_WIN32) || defined(_WIN64)
const char *add = strerror(errno);
#else
const char *add = HandleStrerror(strerror_r(errno, buf, 200), buf);
#endif
if (add) {
*this << add << ' ';
}
}
ErrnoException::~ErrnoException() throw() {}
OverflowException::OverflowException() throw() {}
OverflowException::~OverflowException() throw() {}
#if defined(_WIN32) || defined(_WIN64)
WindowsException::WindowsException() throw() {
unsigned int last_error = GetLastError();
char error_msg[256] = "";
if (!FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, last_error, LANG_NEUTRAL, error_msg, sizeof(error_msg), NULL)) {
*this << "Windows error " << GetLastError() << " while formatting Windows error " << last_error << ". ";
} else {
*this << "Windows error " << last_error << ": " << error_msg;
}
}
WindowsException::~WindowsException() throw() {}
#endif
} // namespace util

156
src/exception.h Normal file
View File

@ -0,0 +1,156 @@
#pragma once
#include <sstream>
#include <exception>
#include <limits>
#include <string>
#include <stdint.h>
namespace util {
template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data);
class Exception : public std::exception {
public:
Exception() throw();
virtual ~Exception() throw();
Exception(const Exception& o) throw();
const char *what() const throw() { return what_.str().c_str(); }
// For use by the UTIL_THROW macros.
void SetLocation(
const char *file,
unsigned int line,
const char *func,
const char *child_name,
const char *condition);
private:
template <class Except, class Data> friend typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data);
// This helps restrict operator<< defined below.
template <class T> struct ExceptionTag {
typedef T Identity;
};
std::stringstream what_;
};
/* This implements the normal operator<< for Exception and all its children.
* SFINAE means it only applies to Exception. Think of this as an ersatz
* boost::enable_if.
*/
template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data) {
e.what_ << data;
return e;
}
#ifdef __GNUC__
#define UTIL_FUNC_NAME __PRETTY_FUNCTION__
#else
#ifdef _WIN32
#define UTIL_FUNC_NAME __FUNCTION__
#else
#define UTIL_FUNC_NAME NULL
#endif
#endif
/* Create an instance of Exception, add the message Modify, and throw it.
* Modify is appended to the what() message and can contain << for ostream
* operations.
*
* do .. while kludge to swallow trailing ; character
* http://gcc.gnu.org/onlinedocs/cpp/Swallowing-the-Semicolon.html .
* Arg can be a constructor argument to the exception.
*/
#define UTIL_THROW_BACKEND(Condition, Exception, Arg, Modify) do { \
Exception UTIL_e Arg; \
UTIL_e.SetLocation(__FILE__, __LINE__, UTIL_FUNC_NAME, #Exception, Condition); \
UTIL_e << Modify; \
throw UTIL_e; \
} while (0)
#define UTIL_THROW_ARG(Exception, Arg, Modify) \
UTIL_THROW_BACKEND(NULL, Exception, Arg, Modify)
#define UTIL_THROW(Exception, Modify) \
UTIL_THROW_BACKEND(NULL, Exception, , Modify);
#define UTIL_THROW2(Modify) \
UTIL_THROW_BACKEND(NULL, util::Exception, , Modify);
#if __GNUC__ >= 3
#define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0)
#else
#define UTIL_UNLIKELY(x) (x)
#endif
#if __GNUC__ >= 3
#define UTIL_LIKELY(x) __builtin_expect (!!(x), 1)
#else
#define UTIL_LIKELY(x) (x)
#endif
#define UTIL_THROW_IF_ARG(Condition, Exception, Arg, Modify) do { \
if (UTIL_UNLIKELY(Condition)) { \
UTIL_THROW_BACKEND(#Condition, Exception, Arg, Modify); \
} \
} while (0)
#define UTIL_THROW_IF(Condition, Exception, Modify) \
UTIL_THROW_IF_ARG(Condition, Exception, , Modify)
#define UTIL_THROW_IF2(Condition, Modify) \
UTIL_THROW_IF_ARG(Condition, util::Exception, , Modify)
// Exception that records errno and adds it to the message.
class ErrnoException : public Exception {
public:
ErrnoException() throw();
virtual ~ErrnoException() throw();
int Error() const throw() { return errno_; }
private:
int errno_;
};
// file wasn't there, or couldn't be open for some reason
class FileOpenException : public Exception {
public:
FileOpenException() throw() {}
~FileOpenException() throw() {}
};
// Utilities for overflow checking.
class OverflowException : public Exception {
public:
OverflowException() throw();
~OverflowException() throw();
};
template <unsigned len> inline std::size_t CheckOverflowInternal(uint64_t value) {
UTIL_THROW_IF(value > static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), OverflowException, "Integer overflow detected. This model is too big for 32-bit code.");
return value;
}
template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) {
return value;
}
inline std::size_t CheckOverflow(uint64_t value) {
return CheckOverflowInternal<sizeof(std::size_t)>(value);
}
#if defined(_WIN32) || defined(_WIN64)
/* Thrown for Windows specific operations. */
class WindowsException : public Exception {
public:
WindowsException() throw();
~WindowsException() throw();
};
#endif
} // namespace util

View File

@ -5,211 +5,111 @@
#include <vector>
#include <cmath>
#include <boost/pool/pool.hpp>
#include "exception.h"
#include "cudnn_tensor.h"
namespace marian {
typedef float Tensor; // Now do this for cuDNN tensors!
struct Chainable;
boost::pool<> p(sizeof(char));
std::vector<Chainable*> stack;
struct Chainable {
template <class DataType>
struct Chainable : public std::enable_shared_from_this<Chainable<DataType>> {
Chainable() { }
virtual ~Chainable() { }
virtual void chain() { }
virtual void forward() { }
virtual void backward() { }
virtual void init_dependent() { }
virtual void set_zero_adjoint() { }
static inline void* operator new(size_t nbytes) {
// thread_local variable
return p.ordered_malloc(nbytes);
}
virtual DataType val() = 0;
virtual DataType grad() = 0;
};
class Vimpl : public Chainable {
typedef std::vector<Chainable<Tensor>*> ChainableStack;
typedef std::shared_ptr<Chainable<Tensor>> ChainPtr;
ChainableStack stack;
class Node : public Chainable<Tensor> {
public:
Vimpl(const Tensor& t) : val_{std::move(t)}, adj_{0} {
Node(const Tensor t) : val_(t) {
//std::cerr << "Putting node with tensor " << t.id() << " on stack" << std::endl;
stack.push_back(this);
}
~Vimpl() {};
virtual ~Node() {};
virtual void init_dependent() { adj_ = 1; }
virtual void set_zero_adjoint() { adj_ = 0; }
virtual void init_dependent() {
if(adj_) {
adj_.set(1);
}
else {
adj_ = Tensor(val_.shape(), 1);
}
}
const Tensor& val() const { return val_; };
Tensor& grad() { return adj_; };
virtual void set_zero_adjoint() {
if(adj_) {
adj_.set(0);
}
else {
adj_ = Tensor(val_.shape(), 0);
}
}
virtual Tensor val() { return val_; };
virtual Tensor grad() { return adj_; };
protected:
const Tensor val_;
Tensor adj_;
Tensor val_;
Tensor adj_;
};
typedef Vimpl* VimplPtr;
static void set_zero_all_adjoints() {
for(auto&& v : stack)
v->set_zero_adjoint();
}
static void grad(Chainable* v) {
typedef std::vector<Chainable*>::reverse_iterator It;
v->init_dependent();
for(It it = stack.rbegin(); it != stack.rend(); ++it) {
(*it)->chain();
}
}
class Var {
public:
Var() : vimpl_{nullptr} {}
Var(const Tensor& t) : vimpl_{new Vimpl{t}} {}
Var(const VimplPtr& vimpl) : vimpl_{vimpl} {}
Var() : pimpl_(nullptr) {}
Var(const Tensor t) : pimpl_(new Node(t)) {}
Var(const Tensor::value_type v) : pimpl_(new Node(Tensor(v))) {}
Var(const ChainPtr chainable) : pimpl_(chainable) {}
Var(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
const Tensor& val() const {
return vimpl_->val();
Tensor val() {
return pimpl_->val();
}
Tensor& grad() {
return vimpl_->grad();
Tensor grad() {
return pimpl_->grad();
}
VimplPtr vimpl() const {
return vimpl_;
ChainPtr pimpl() {
return pimpl_;
}
void calc_gradients() {
marian::grad(vimpl_);
void forward() {
UTIL_THROW_IF2(pimpl_.get() != stack.back(),
"Trying to call forward on non-root of computation graph");
for(auto&& v : stack)
v->forward();
}
void backward() {
UTIL_THROW_IF2(pimpl_.get() != stack.back(),
"Trying to call backward on non-root of computation graph");
for(auto&& v : stack)
v->set_zero_adjoint();
typedef ChainableStack::reverse_iterator It;
pimpl_->init_dependent();
for(It it = stack.rbegin(); it != stack.rend(); ++it)
(*it)->backward();
}
operator ChainPtr() {
return pimpl_;
}
private:
VimplPtr vimpl_;
ChainPtr pimpl_;
};
///////////////////////////////////////////////////
struct OpVimpl : public Vimpl {
OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
VimplPtr a_;
};
struct LogVimpl : public OpVimpl {
LogVimpl(VimplPtr a) : OpVimpl(std::log(a->val()), a) { }
void chain() {
a_->grad() += adj_ / a_->val();
}
};
inline Var log(const Var& a) {
return Var(VimplPtr(new LogVimpl(a.vimpl())));
}
struct ExpVimpl : public OpVimpl {
ExpVimpl(VimplPtr a) : OpVimpl(std::exp(a->val()), a) { }
void chain() {
a_->grad() += adj_ * std::exp(a_->val());
}
};
inline Var exp(const Var& a) {
return Var(VimplPtr(new ExpVimpl(a.vimpl())));
}
struct NegVimpl : public OpVimpl {
NegVimpl(VimplPtr a) : OpVimpl(-a->val(), a) { }
void chain() {
a_->grad() -= adj_;
}
};
inline Var operator-(const Var& a) {
return Var(VimplPtr(new NegVimpl(a.vimpl())));
}
// @TODO: take care of large exponents
struct SigmaVimpl : public OpVimpl {
SigmaVimpl(VimplPtr a) : OpVimpl(1.f / (1.f + std::exp(-a->val())), a) { }
void chain() {
Tensor l = 1.f / (1.f + std::exp(-a_->val()));
a_->grad() += adj_ * l * (1 - l);
}
};
inline Var sigma(const Var& a) {
return Var(VimplPtr(new SigmaVimpl(a.vimpl())));
}
///////////////////////////////////////////////////
struct OpVimplVV : public Vimpl {
VimplPtr a_;
VimplPtr b_;
OpVimplVV(Tensor t, VimplPtr a, VimplPtr b)
: Vimpl(t), a_(a), b_(b) { }
};
struct PlusVimplVV : public OpVimplVV {
PlusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() + b->val(), a, b) { }
void chain() {
a_->grad() += adj_;
b_->grad() += adj_;
}
};
inline Var operator+(const Var& a, const Var& b) {
return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
}
struct MinusVimplVV : public OpVimplVV {
MinusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() - b->val(), a, b) { }
void chain() {
a_->grad() -= adj_;
b_->grad() -= adj_;
}
};
inline Var operator-(const Var& a, const Var& b) {
return Var(VimplPtr(new MinusVimplVV(a.vimpl(), b.vimpl())));
}
struct MultVimplVV : public OpVimplVV {
MultVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() * b->val(), a, b) { }
void chain() {
a_->grad() += adj_ * b_->val();
b_->grad() += adj_ * a_->val();
}
};
inline Var operator*(const Var& a, const Var& b) {
return Var(VimplPtr(new MultVimplVV(a.vimpl(), b.vimpl())));
}
struct DivVimplVV : public OpVimplVV {
DivVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() / b->val(), a, b) { }
void chain() {
a_->grad() += adj_ / b_->val();
b_->grad() += adj_ * (a_->val() / (b_->val() * b_->val()));
}
};
inline Var operator/(const Var& a, const Var& b) {
return Var(VimplPtr(new DivVimplVV(a.vimpl(), b.vimpl())));
}
}

370
src/operators.h Normal file
View File

@ -0,0 +1,370 @@
#pragma once
#include <memory>
#include <functional>
#include <vector>
#include <cmath>
#include "marian.h"
#include "cudnn_tensor.h"
namespace marian {
/*** Unary operators ***/
struct UnaryNodeOp : public Node {
ChainPtr a_;
UnaryNodeOp(const Tensor t, ChainPtr a)
: Node(t), a_(a) {}
};
struct SigmaNodeOp : public UnaryNodeOp {
SigmaNodeOp(ChainPtr a)
: UnaryNodeOp(Tensor(a->val().shape()), a) { }
void forward() {
Element(_1 = Sigma(_2),
val_, a_->val());
}
void backward() {
Element(_1 += _2 * Sigma(_3) * (1 - Sigma(_3)),
a_->grad(), adj_, a_->val());
}
};
inline Var sigma(Var a) {
return Var(new SigmaNodeOp(a));
}
struct TanhNodeOp : public UnaryNodeOp {
TanhNodeOp(ChainPtr a)
: UnaryNodeOp(Tensor(a->val().shape()), a) { }
void forward() {
Element(_1 = Tanh(_2),
val_, a_->val());
}
void backward() {
Element(_1 += _2 * (1 - Tanh(_3) * Tanh(_3)),
a_->grad(), adj_, a_->val());
}
};
inline Var tanh(Var a) {
return Var(new TanhNodeOp(a));
}
struct LogNodeOp : public UnaryNodeOp {
LogNodeOp(ChainPtr a)
: UnaryNodeOp(Tensor(a->val().shape()), a) { }
void forward() {
Element(_1 = Log(_2), val_, a_->val());
}
void backward() {
Element(_1 += _2 * 1.f / _3,
a_->grad(), adj_, a_->val());
}
};
inline Var log(Var a) {
return Var(new LogNodeOp(a));
};
struct ExpNodeOp : public UnaryNodeOp {
ExpNodeOp(ChainPtr a)
: UnaryNodeOp(Tensor(a->val().shape()), a) { }
void forward() {
Element(_1 = Exp(_2), val_, a_->val());
}
void backward() {
Element(_1 += _2 * Exp(_3),
a_->grad(), adj_, a_->val());
}
};
inline Var exp(Var a) {
return Var(new ExpNodeOp(a));
};
struct NegNodeOp : public UnaryNodeOp {
NegNodeOp(ChainPtr a)
: UnaryNodeOp(Tensor(a->val().shape()), a) { }
void forward() {
Element(_1 = -_2, val_, a_->val());
}
void backward() {
Element(_1 += -_2, a_->grad(), adj_);
}
};
inline Var operator-(Var a) {
return Var(new NegNodeOp(a));
};
/******************************************************/
struct BinaryNodeOp : public Node {
ChainPtr a_;
ChainPtr b_;
BinaryNodeOp(const Tensor t, ChainPtr a, ChainPtr b)
: Node(t), a_(a), b_(b) {}
};
/*** Matrix Product ***/
struct DotNodeOp : public BinaryNodeOp {
DotNodeOp(ChainPtr a, ChainPtr b) : BinaryNodeOp(Tensor(shape(a, b)), a, b) { }
Shape shape(ChainPtr a, ChainPtr b) {
UTIL_THROW_IF2(a->val().shape()[1] != b->val().shape()[0],
"matrix product requires dimensions to match");
Shape shape1 = a->val().shape();
Shape shape2 = b->val().shape();
shape1[1] = shape2[1];
return shape1;
}
void forward() {
// C = A*B
Prod(val_, a_->val(), b_->val(), false, false);
}
void backward() {
// D is the adjoint, the matrix of derivatives
// df/dA += D*B.T
// df/dB += A.T*D
// beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
// to sum gradients from different graph parts
Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
}
};
inline Var dot(Var a, Var b) {
return Var(new DotNodeOp(a, b));
}
/******************************************************/
Var broadcast(Shape shape, Var a) {
if(a.val().shape() == shape) {
return a;
}
else {
size_t dimsA = a.val().shape().size();
size_t dimsB = shape.size();
UTIL_THROW_IF2(dimsA != dimsB,
"Tensor and shape have different number of dimensions");
for(size_t i = 0; i < dimsA; ++i) {
int dimA = a.val().shape()[i];
int dimB = shape[i];
bool broadcastable = (dimA == dimB || dimA == 1);
UTIL_THROW_IF2(!broadcastable,
"Cannot broadcast tensor dimension "
<< dimA << " to " << dimB);
if(dimA == 1 && dimB > 1) {
std::cerr << "Broadcasting dim " << i << " from " << dimA << " to " << dimB << std::endl;
if(i == 0) {
Var one = Tensor({shape[0], 1}, 1);
a = dot(one, a);
}
else if(i == 1) {
Var one = Tensor({1, shape[1]}, 1);
a = dot(a, one);
}
else {
UTIL_THROW2("Not inplemented");
}
}
}
return a;
}
}
struct BroadcastingNodeOp : public BinaryNodeOp {
BroadcastingNodeOp(Var a, Var b)
: BroadcastingNodeOp(Tensor(shape(a ,b)), broadcast(shape(a ,b), a), broadcast(shape(a ,b), b)) {}
static Shape shape(ChainPtr a, ChainPtr b) {
size_t dimsA = a->val().shape().size();
size_t dimsB = b->val().shape().size();
UTIL_THROW_IF2(dimsA != dimsB,
"Tensors have different numbers of dimensions");
Shape shape(dimsA);
for(size_t i = 0; i < dimsA; ++i) {
int dimA = a->val().shape()[i];
int dimB = b->val().shape()[i];
bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1);
UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
<< "operation cannot be broadcasted: " << dimA << " != " << dimB);
shape[i] = std::max(dimA, dimB);
}
return shape;
}
private:
BroadcastingNodeOp(const Tensor t, ChainPtr a, ChainPtr b)
: BinaryNodeOp(t, a, b) {}
};
/*** Binary arithmetic ***/
/*** Plus ***/
struct PlusNodeOp : public BroadcastingNodeOp {
PlusNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
void forward() {
Element(_1 = _2 + _3,
val_, a_->val(), b_->val());
}
void backward() {
Element(_1 += _2,
a_->grad(), adj_);
Element(_1 += _2,
b_->grad(), adj_);
}
};
inline Var operator+(Var a, Var b) {
return Var(new PlusNodeOp(a, b));
}
/*** Minus ***/
struct MinusNodeOp : public BroadcastingNodeOp {
MinusNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
void forward() {
Element(_1 = _2 - _3,
val_, a_->val(), b_->val());
}
void backward() {
Element(_1 += _2,
a_->grad(), adj_);
Element(_1 -= _2,
b_->grad(), adj_);
}
};
inline Var operator-(Var a, Var b) {
return Var(new MinusNodeOp(a, b));
}
/*** Mult ***/
struct MultNodeOp : public BroadcastingNodeOp {
MultNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
void forward() {
Element(_1 = _2 * _3,
val_, a_->val(), b_->val());
}
void backward() {
Element(_1 += _2 * _3,
a_->grad(), adj_, b_->val());
Element(_1 += _2 * _3,
b_->grad(), adj_, a_->val());
}
};
inline Var operator*(Var a, Var b) {
return Var(new MultNodeOp(a, b));
}
/*** Division ***/
struct DivNodeOp : public BroadcastingNodeOp {
DivNodeOp(Var a, Var b) : BroadcastingNodeOp(a, b) { }
void forward() {
Element(_1 = _2 / _3,
val_, a_->val(), b_->val());
}
void backward() {
Element(_1 += _2 * 1.0f / _3,
a_->grad(), adj_, b_->val());
Element(_1 -= _2 * _3 / (_4 * _4),
b_->grad(), adj_, a_->val(), b_->val());
}
};
inline Var operator/(Var a, Var b) {
return Var(new DivNodeOp(a, b));
}
/*** Reductions ***/
enum Axis { undef, axis0, axis1, axis2, axis3 };
// inefficient
inline Var sum(Var a, Axis axis = Axis::undef) {
if(axis == Axis::axis0) {
int rows = a.val().shape()[0];
int cols = a.val().shape()[1];
Var one = Tensor({1, rows}, 1);
return dot(one, a);
}
else if(axis == Axis::axis1) {
int rows = a.val().shape()[0];
int cols = a.val().shape()[1];
Var one = Tensor({cols, 1}, 1);
return dot(a, one);
}
else if(axis == Axis::axis2) {
UTIL_THROW2("Not inplemented");
}
else if(axis == Axis::axis3) {
UTIL_THROW2("Not inplemented");
}
return sum(sum(a, Axis::axis0), Axis::axis1);
}
// inefficient
inline Var softmax(Var a, Axis axis = Axis::undef) {
Var e = exp(a);
return e / sum(e, axis);
}
// inefficient
inline Var mean(Var a, Axis axis = Axis::undef) {
switch (axis) {
case Axis::axis0:
return sum(a, axis) / a.val().shape()[0];
case Axis::axis1:
return sum(a, axis) / a.val().shape()[1];
case Axis::axis2:
UTIL_THROW2("Not implemented");
case Axis::axis3:
UTIL_THROW2("Not implemented");
case Axis::undef:
default:
return sum(a) / a.val().size();
}
}
// FAKE
inline Var input(const std::string& name, Var v) {
return v;
}
inline Var forsave(const std::string& name, Var v) {
return v;
}
}

117
src/tensor.h Normal file
View File

@ -0,0 +1,117 @@
#pragma once
#include <memory>
#include <functional>
#include <vector>
#include <cmath>
namespace marian {
class TensorImpl {
public:
typedef float value_type;
TensorImpl(size_t size, value_type value)
: data_(size, value), tno_(tensorCounter++)
{
std::cerr << "Allocating tensor " << tno_ << std::endl;
}
TensorImpl(const TensorImpl& t)
: data_(t.data_.begin(), t.data_.end())
{
std::cerr << "Copying tensor " << tno_ << std::endl;
}
~TensorImpl() {
std::cerr << "Destroying tensor " << tno_ << std::endl;
}
size_t size() const {
return data_.size();
}
value_type* data() {
return data_.data();
}
const value_type* data() const {
return data_.data();
}
size_t id() const {
return tno_;
}
void set(value_type value) {
std::fill(data_.begin(), data_.end(), value);
}
private:
std::vector<value_type> data_;
size_t tno_;
static size_t tensorCounter;
};
size_t TensorImpl::tensorCounter = 0;
class Tensor {
public:
typedef TensorImpl::value_type value_type;
Tensor(size_t size, float value)
: pimpl_(new TensorImpl(size, value)) {}
Tensor() {}
~Tensor() {}
size_t size() const {
return pimpl_->size();
}
float* data() {
return pimpl_->data();
}
const float* data() const {
return pimpl_->data();
}
void set(float value) {
pimpl_->set(value);
}
size_t id() const {
return pimpl_->id();
}
private:
std::shared_ptr<TensorImpl> pimpl_;
};
Tensor operator+(const Tensor a, const Tensor b) {
Tensor c(a.size(), 0);
for(size_t i = 0; i < a.size(); ++i) {
c.data()[i] = a.data()[i] + b.data()[i];
}
return c;
}
Tensor operator*(const Tensor a, const Tensor b) {
Tensor c(a.size(), 0);
for(size_t i = 0; i < a.size(); ++i) {
c.data()[i] = a.data()[i] * b.data()[i];
}
return c;
}
Tensor operator+=(Tensor a, const Tensor b) {
for(size_t i = 0; i < a.size(); ++i) {
a.data()[i] += b.data()[i];
}
return a;
}
}

View File

@ -1,55 +0,0 @@
#include <iostream>
#include <ctime>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <cublas_v2.h>
#include <cudnn.h>
#include "marian.h"
using namespace marian;
Var layer(size_t max, std::vector<Var>& x) {
Var x0 = rand() % 100, x1 = rand() % 100, x2 = rand() % 100;
x = { x0, x1, x2 };
Var y = 0.0;
for(int i = 0; i < max; i++) {
Var xi = i;
x.push_back(xi);
y = y + x0 + log(x2) + x1;
for(int j = 0; j < i; ++j) {
y = y + xi;
}
}
return y;
}
int main(int argc, char** argv) {
srand(time(NULL));
std::vector<Var> x1, x2;
Var y1 = layer(10, x1);
Var y2 = layer(rand() % 20 + 1, x2);
Var y = sigma(log(y1) / log(y2));
set_zero_all_adjoints();
y.calc_gradients();
std::cerr << "y1 = " << y1.val() << std::endl;
std::cerr << "y2 = " << y2.val() << std::endl;
std::cerr << "y = " << y.val() << std::endl;
std::cerr << "dy/dy1 = " << y1.grad() << std::endl;
std::cerr << "dy/dy2 = " << y2.grad() << std::endl;
for(size_t i = 0; i < x1.size(); ++i)
std::cerr << "x1_" << i << " = " << x1[i].val() << " : dy/dx1_" << i << " = " << x1[i].grad() << std::endl;
for(size_t i = 0; i < x2.size(); ++i)
std::cerr << "x2_" << i << " = " << x2[i].val() << " : dy/dx2_" << i << " = " << x2[i].grad() << std::endl;
}

71
src/test.cu Normal file
View File

@ -0,0 +1,71 @@
#include <iostream>
#include <ctime>
#include <vector>
#include <algorithm>
#include <random>
#include <boost/timer/timer.hpp>
#include "marian.h"
#include "operators.h"
using namespace marian;
int main(int argc, char** argv) {
boost::timer::auto_cpu_timer t;
Var x = input("X", Tensor({4, 2}));
Var y = input("Y", Tensor({4, 2}));
std::vector<float> vx = {
0, 0,
0, 1,
1, 0,
1, 1
};
std::vector<float> vy = {
1, 0,
1, 0,
0, 1,
1, 0
};
thrust::copy(vx.begin(), vx.end(), x.val().begin());
thrust::copy(vy.begin(), vy.end(), y.val().begin());
Var w0 = forsave("W0", uniform(Tensor({2, 2})));
Var b0 = forsave("b0", uniform(Tensor({1, 2})));
Var w1 = forsave("W1", uniform(Tensor({2, 2})));
Var b1 = forsave("b1", uniform(Tensor({1, 2})));
std::vector<Var> params = { w0, w1, b0, b1 };
Var ry = sigma(dot(x, w0) + b0);
ry = softmax(dot(ry, w1) + b1, Axis::axis1);
Var cost = -mean(sum(y * log(ry), Axis::axis1), Axis::axis0);
float alpha = 0.1;
for(size_t i = 0; i < 30000; ++i) {
cost.forward();
if(i % 100 == 0) {
for(size_t j = 0; j < 4; ++j) {
std::cerr << ry.val()[j*2] << std::endl;
}
std::cerr << i << " ct: " << cost.val()[0] << std::endl;
// alpha = alpha * 0.9;
}
cost.backward();
for(auto p : params) {
//std::cerr << p.grad()[0] << std::endl;
auto update =
_1 -= alpha * _2;
Element(update, p.val(), p.grad());
}
}
return 0;
}

95
src/thrust_functions.h Normal file
View File

@ -0,0 +1,95 @@
#pragma once
#include <cmath>
#include <cublas_v2.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
namespace thrust
{
namespace detail
{
namespace functional
{
// Ugly hacks, but it seems this is neccessary.
__host__ __device__
float expf2(float x) {
float clip = 16;
if(x > clip)
x = clip;
if(x < -clip)
x = -clip;
return expf(x);
}
__host__ __device__
float logf2(float x) {
if(x < 10e-10)
x = 10e-10;
return logf(x);
}
template<typename T>
struct unary_exp : public thrust::unary_function<T,T> {
__host__ __device__
T operator()(const T &x) const { return expf2(x); }
};
template<typename Eval>
__host__ __device__
actor<composite<unary_operator<unary_exp>, actor<Eval>>>
Exp(const actor<Eval> &_1) {
return compose(unary_operator<unary_exp>(), _1);
}
template<typename T>
struct unary_log : public thrust::unary_function<T,T> {
__host__ __device__
T operator()(const T &x) const { return logf2(x); }
};
template<typename Eval>
__host__ __device__
actor<composite<unary_operator<unary_log>, actor<Eval>>>
Log(const actor<Eval> &_1) {
return compose(unary_operator<unary_log>(), _1);
}
template<typename T>
struct unary_sigma : public thrust::unary_function<T,T> {
__host__ __device__
T operator()(const T &x) const { return 1.0 / (1.0 + expf2(-x)); }
};
template<typename Eval>
__host__ __device__
actor<composite<unary_operator<unary_sigma>, actor<Eval>>>
Sigma(const actor<Eval> &_1) {
return compose(unary_operator<unary_sigma>(), _1);
}
template<typename T>
struct unary_tanh : public thrust::unary_function<T,T> {
__host__ __device__
T operator()(const T &x) const { return tanhf(x); }
};
template<typename Eval>
__host__ __device__
actor<composite<unary_operator<unary_tanh>, actor<Eval>>>
Tanh(const actor<Eval> &_1) {
return compose(unary_operator<unary_tanh>(), _1);
}
template<typename T1, typename T2>
__host__ __device__
actor<composite<binary_operator<thrust::maximum>, actor<T1>, actor<T2>>>
Max(const actor<T1> &_1, const actor<T2> &_2) {
return compose(binary_operator<thrust::maximum>(),
make_actor(_1),
make_actor(_2));
}
}
}
}