Roman Grundkiewicz 2016-09-13 16:13:38 +02:00
commit 142a1e5435
5 changed files with 34 additions and 408 deletions

View File

@@ -1,6 +1,9 @@
#include <sstream>
#include "expressions.h"
#include "graph_operators.h"
using namespace std;
namespace marian {
Expr::Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
@@ -49,5 +52,14 @@ void Expr::backward() {
Expr::operator ChainPtr() {
return pimpl_;
}
std::string Expr::Debug() const
{
stringstream strm;
//const Chainable<Tensor> &ct = *pimpl_;
const Shape &shape = pimpl_->shape();
strm << marian::Debug(shape);
return strm.str();
}
}

View File

@@ -24,8 +24,10 @@ class Expr {
ChainPtr node();
operator ChainPtr();
std::string Debug() const;
private:
ChainPtr pimpl_;
};
}

View File

@@ -1,399 +0,0 @@
#pragma once
#include <memory>
#include <functional>
#include <vector>
#include <cmath>
#include <cudnn.h>
#include <cublas_v2.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include "definitions.h"
#include "exception.h"
#include "thrust_functions.h"
namespace marian {
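// RAII holder for the cuDNN and cuBLAS library handles plus an op-tensor
// descriptor preconfigured for element-wise addition; a single shared
// instance is defined below.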
struct Handles {
cudnnHandle_t cudnnHandle;
cublasHandle_t cublasHandle;
cudnnOpTensorDescriptor_t add;
Handles() {
cudnnCreate(&cudnnHandle);
cublasCreate(&cublasHandle);
cudnnCreateOpTensorDescriptor(&add);
cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN);
}
~Handles() {
cudnnDestroy(cudnnHandle);
cublasDestroy(cublasHandle);
cudnnDestroyOpTensorDescriptor(add);
}
};
Handles handles;
typedef std::vector<int> Shape;
template<class Float>
class TensorImpl {
private:
Shape shape_;
thrust::device_vector<Float> data_;
cudnnTensorDescriptor_t desc_;
size_t tno_;
static size_t tensorCounter;
cudnnDataType_t dataType() {
switch(sizeof(Float)) {
case 2: return CUDNN_DATA_HALF;
case 8: return CUDNN_DATA_DOUBLE;
default: return CUDNN_DATA_FLOAT;
}
}
public:
typedef Float value_type;
TensorImpl(const Shape& shape, value_type value = 0)
: shape_(shape), tno_(tensorCounter++)
{
// @TODO:
UTIL_THROW_IF2(shape_.size() != 2,
"For now, only 2D Tensors, will be fixed later.");
UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4,
"Wrong number of dimensions: " << shape_.size());
std::cerr << "Allocating : " << shape[0] << " " << shape[1] << std::endl;
int size = std::accumulate(shape_.begin(), shape_.end(),
1, std::multiplies<int>());
data_.resize(size, value);
cudnnCreateTensorDescriptor(&desc_);
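// cuDNN descriptors are always 4D (NCHW) here, so missing trailing
// dimensions are padded with 1.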
switch (shape_.size()) {
case 1:
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
shape_[0], 1, 1, 1); break;
case 2:
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
shape_[0], shape_[1], 1, 1); break;
case 3:
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
shape_[0], shape_[1], shape_[2], 1); break;
case 4:
cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
shape_[0], shape_[1], shape_[2], shape_[3]); break;
}
}
TensorImpl(const TensorImpl&) = delete;
TensorImpl(TensorImpl&&) = delete;
~TensorImpl() {
cudnnDestroyTensorDescriptor(desc_);
}
value_type operator[](size_t i) const {
return data_[i];
}
auto begin() -> decltype( data_.begin() ) {
return data_.begin();
}
auto begin() const -> decltype( data_.begin() ) {
return data_.begin();
}
auto end() -> decltype( data_.end() ) {
return data_.end();
}
auto end() const -> decltype( data_.end() ) {
return data_.end();
}
const Shape& shape() const {
return shape_;
}
size_t size() const {
return data_.size();
}
value_type* data() {
return thrust::raw_pointer_cast(data_.data());
}
cudnnTensorDescriptor_t desc() const {
return desc_;
}
size_t id() const {
return tno_;
}
void set(value_type value) {
thrust::fill(data_.begin(), data_.end(), value);
}
};
template <typename Type>
size_t TensorImpl<Type>::tensorCounter = 0;
class Tensor {
private:
std::shared_ptr<TensorImpl<Float>> pimpl_;
public:
typedef TensorImpl<Float>::value_type value_type;
Tensor() {}
~Tensor() {}
void allocate(Shape shape, value_type value = 0) {
pimpl_.reset(new TensorImpl<Float>(shape, value));
}
value_type operator[](size_t i) const {
return (*pimpl_)[i];
}
size_t size() const {
return pimpl_->size();
}
value_type* data() {
return pimpl_->data();
}
const value_type* data() const {
return pimpl_->data();
}
auto begin() -> decltype( pimpl_->begin() ) {
return pimpl_->begin();
}
auto begin() const -> decltype( pimpl_->begin() ) {
return pimpl_->begin();
}
auto end() -> decltype( pimpl_->end() ) {
return pimpl_->end();
}
auto end() const -> decltype( pimpl_->end() ) {
return pimpl_->end();
}
const Shape& shape() const {
return pimpl_->shape();
}
cudnnTensorDescriptor_t desc() const {
return pimpl_->desc();
}
void set(value_type value) {
pimpl_->set(value);
}
size_t id() const {
return pimpl_->id();
}
operator bool() {
return pimpl_ != nullptr;
}
};
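// Fills the tensor with pseudo-random values in roughly [-0.1, 0.1);
// note that the a and b bounds are not used yet.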
Tensor uniform(Tensor t, Float a=-0.1, Float b=0.1) {
std::vector<Float> r(t.size());
for(int i = 0; i < r.size(); i++)
r[i] = (Float(rand() % 2000) - 1000.0)/10000.0;
thrust::copy(r.begin(), r.end(), t.begin());
return t;
}
using namespace thrust::placeholders;
#define MAX_THREADS 512
#define MAX_BLOCKS 65535
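// Element-wise CUDA kernels: blocks stride over rows and threads stride over
// columns, so the loops also cover tensors larger than the launch grid.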
template <class Functor>
__global__ void gElement(Functor functor, Float* out,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
Float* rowOut = out + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols)
rowOut[i] = functor(rowOut[i]);
}
}
}
}
template <class Functor>
__global__ void gElement(Functor functor,
Float* out, const Float* in,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
Float* rowOut = out + j * cols;
const Float* rowIn = in + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols)
rowOut[i] = functor(rowOut[i], rowIn[i]);
}
}
}
}
template <class Functor>
__global__ void gElement(Functor functor,
Float* out, const Float* in1, const Float* in2,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
Float* rowOut = out + j * cols;
const Float* rowIn1 = in1 + j * cols;
const Float* rowIn2 = in2 + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols)
rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i]);
}
}
}
}
template <class Functor>
__global__ void gElement(Functor functor,
Float* out, const Float* in1,
const Float* in2, const Float* in3,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
Float* rowOut = out + j * cols;
const Float* rowIn1 = in1 + j * cols;
const Float* rowIn2 = in2 + j * cols;
const Float* rowIn3 = in3 + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols)
rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i], rowIn3[i]);
}
}
}
}
// @TODO add broadcasting
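// Host-side wrappers: launch gElement with one block per row and one thread
// per column (capped at MAX_BLOCKS / MAX_THREADS), then synchronize the
// default stream.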
template <class Functor>
void Element(Functor functor, Tensor Out) {
Float* d_out = Out.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gElement<<<blocks, threads>>>(functor, d_out,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
template <class Functor>
void Element(Functor functor,
Tensor Out, const Tensor In) {
Float* d_out = Out.data();
const Float* d_in = In.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gElement<<<blocks, threads>>>(functor, d_out, d_in,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
template <class Functor>
void Element(Functor functor,
Tensor Out, const Tensor In1, const Tensor In2) {
Float* d_out = Out.data();
const Float* d_in1 = In1.data();
const Float* d_in2 = In2.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
template <class Functor>
void Element(Functor functor,
Tensor Out, const Tensor In1,
const Tensor In2, const Tensor In3) {
Float* d_out = Out.data();
const Float* d_in1 = In1.data();
const Float* d_in2 = In2.data();
const Float* d_in3 = In3.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2, d_in3,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
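// Row-major GEMM on top of column-major cuBLAS: passing the operands as B, A
// computes B^T * A^T = (A * B)^T in column-major layout, which is exactly
// A * B when the result buffer is read back as row-major.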
Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, Float beta) {
Float alpha = 1.0;
size_t m = A.shape()[0];
size_t k = A.shape()[1];
if(transA)
std::swap(m, k);
size_t l = B.shape()[0];
size_t n = B.shape()[1];
if(transB)
std::swap(l, n);
size_t lda = A.shape()[1];
size_t ldb = B.shape()[1];
size_t ldc = B.shape()[1];
if(transB)
ldc = B.shape()[0];
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasSgemm(handle, opB, opA,
n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc);
return C;
}
Tensor Prod(Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, Float beta = 0) {
return Prod(handles.cublasHandle, C, A, B, transA, transB, beta);
}
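// Minimal usage sketch (assuming the thrust placeholders imported above):
// Element(_2 + _3, C, A, B) would set C = A + B element-wise, and
// Prod(C, A, B, false, false) computes C = A * B with the global handle.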
}

View File

@@ -37,6 +37,16 @@ const Handles handles;
typedef std::vector<int> Shape;
inline std::string Debug(const Shape &shape)
{
std::stringstream strm;
assert(shape.size());
strm << shape[0];
for (size_t i = 1; i < shape.size(); ++i) {
strm << "x" << shape[i];
}
return strm.str();
}
template<class Float>
class TensorImpl {
private:
@@ -145,10 +155,7 @@ class TensorImpl {
{
std::stringstream strm;
assert(shape_.size());
strm << "shape=" << shape_[0];
for (size_t i = 1; i < shape_.size(); ++i) {
strm << "x" << shape_[i];
}
strm << "shape=" << marian::Debug(shape_);
return strm.str();
}
};

View File

@@ -14,13 +14,17 @@ int main(int argc, char** argv) {
Expr w = param(shape={784, 10}, name="W0");
Expr b = param(shape={1, 10}, name="b0");
Expr lr = softmax(dot(x, w) + b, axis=1, name="pred");
Expr n5 = dot(x, w);
Expr n6 = n5 + b;
Expr lr = softmax(n6, axis=1, name="pred");
cerr << "lr=" << lr.Debug() << endl;
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
Tensor tx({500, 784}, 1);
Tensor ty({500, 10}, 1);
cerr << "tx=" << tx.Debug();
cerr << "ty=" << ty.Debug();
cerr << "tx=" << tx.Debug() << endl;
cerr << "ty=" << ty.Debug() << endl;
x = tx;
y = ty;