removed dependencies on extended lambdas, fixes #15

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-22 15:16:03 +02:00
parent ade449a06a
commit 56e241695c
5 changed files with 62 additions and 37 deletions

View File

@ -3,7 +3,7 @@ set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
project(marian CXX)
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math; --expt-extended-lambda; --expt-relaxed-constexpr; -Xcompiler '-fPIC')
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math; -Xcompiler '-fPIC')
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)

View File

@ -15,11 +15,11 @@ class Bernoulli {
curandState* states_;
static unsigned seed;
Shape shape_;
public:
Bernoulli(float p, const Shape& shape)
: p_(p), shape_(shape) {}
void InitStates(curandState* states) {
states_ = states;
int blocks = std::min(MAX_BLOCKS, shape_[0]);
@ -29,29 +29,57 @@ class Bernoulli {
gInitCurandStates<<<blocks, threads>>>(states_, seed++);
cudaStreamSynchronize(0);
}
void FreeStates(curandState* states) {
cudaFree(states);
}
__device__ float operator()(int i, int j) const {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
float dist = curand_uniform(&states_[tid]);
float dist = curand_uniform(&states_[tid]);
float zeroOne = dist > p_;
return zeroOne / (1 - p_);
}
__device__ int rows() const {
return shape_[0];
}
__device__ int cols() const {
return shape_[1];
}
Bernoulli& gpu() {
return *this;
}
};
}
template <class T1, class T2>
__global__ void gDropout(T1 out, T2 drop) {
int rows = out.rows();
int cols = out.cols();
for(int bid = 0; bid < rows; bid += gridDim.x) {
int i = bid + blockIdx.x;
if(i < rows) {
for(int tid = 0; tid < cols; tid += blockDim.x) {
int j = tid + threadIdx.x;
if(j < cols)
out(i, j) = drop(i, j);
}
}
}
}
template <class T1, class T2>
void Dropout(T1 out, T2 drop) {
int m = out.shape()[0];
int n = out.shape()[1];
int blocks = std::min(MAX_BLOCKS, m);
int threads = std::min(MAX_THREADS, n);
gDropout<<<blocks, threads>>>(out.gpu(), drop.gpu());
cudaStreamSynchronize(0);
}
}

View File

@ -32,21 +32,23 @@ using namespace marian;
using namespace keywords;
int main(int argc, char** argv) {
Tensor a({1000, 1000}, 3);
Tensor mask({1000, 1000});
Tensor b({1000, 1000});
Bernoulli dropout(0.2, b.shape());
auto f = [] __device__ (float& r,
float a,
float b) {
return r = a * b;
};
Bernoulli dropout(0.2, mask.shape());
curandState* states = nullptr;
dropout.InitStates(states);
boost::timer::cpu_timer timer;
for(int i = 0; i < 1000; ++i)
Element(f, b, a, a);
for(int i = 0; i < 1000; ++i) {
Dropout(mask, dropout);
Element(_1 = _2 * _3, b, mask, a);
}
std::cerr << timer.format(5, "%ws") << std::endl;
dropout.FreeStates(states);
return 0;
}

View File

@ -108,7 +108,6 @@ struct ReLUNodeOp : public UnaryNodeOp {
};
// Scaling droput
struct DropoutNodeOp : public UnaryNodeOp {
template <typename ...Args>
DropoutNodeOp(Args ...args)
@ -119,28 +118,25 @@ struct DropoutNodeOp : public UnaryNodeOp {
if(bernoulli)
bernoulli->FreeStates(states_);
}
void inference() {
Element(_1 = _2, val_, a_->val());
}
void forward() {
if(!bernoulli) {
bernoulli.reset(new Bernoulli(p_, val_.shape()));
bernoulli->InitStates(states_);
}
if(!mask_)
mask_.allocate(val_.shape());
auto f = [] __device__ (float& mask, float drop) {
return mask = drop;
};
Element(f, mask_, *bernoulli);
Dropout(mask_, *bernoulli);
Element(_1 = _2 * _3, val_, mask_, a_->val());
}
void backward() {
void backward() {
Element(_1 += _2 * _3, a_->grad(), adj_, mask_);
}
@ -321,4 +317,3 @@ struct NegNodeOp : public UnaryNodeOp {
}

View File

@ -51,7 +51,7 @@ void Element(Functor functor, T out) {
int m = out.shape()[0];
int n = out.shape()[1];
int blocks = std::min(MAX_BLOCKS, m);
int threads = std::min(MAX_THREADS, n);
gElement<<<blocks, threads>>>(functor, out.gpu());
@ -82,7 +82,7 @@ void Element(Functor functor,
int m = out.shape()[0];
int n = out.shape()[1];
int blocks = std::min(MAX_BLOCKS, m);
int threads = std::min(MAX_THREADS, n);
gElement<<<blocks, threads>>>(functor, out.gpu(), in.gpu());
@ -112,7 +112,7 @@ void Element(Functor functor,
int m = out.shape()[0];
int n = out.shape()[1];
int blocks = std::min(MAX_BLOCKS, m);
int threads = std::min(MAX_THREADS, n);
gElement<<<blocks, threads>>>(functor, out.gpu(),
@ -143,7 +143,7 @@ void Element(Functor functor,
int m = out.shape()[0];
int n = out.shape()[1];
int blocks = std::min(MAX_BLOCKS, m);
int threads = std::min(MAX_THREADS, n);
gElement<<<blocks, threads>>>(functor, out.gpu(),