mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
removed dependencies on extended lambdas, fixes #15
This commit is contained in:
parent
ade449a06a
commit
56e241695c
@ -3,7 +3,7 @@ set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
|
||||
|
||||
project(marian CXX)
|
||||
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
|
||||
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math; --expt-extended-lambda; --expt-relaxed-constexpr; -Xcompiler '-fPIC')
|
||||
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math; -Xcompiler '-fPIC')
|
||||
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
|
||||
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
|
||||
|
||||
|
@ -15,11 +15,11 @@ class Bernoulli {
|
||||
curandState* states_;
|
||||
static unsigned seed;
|
||||
Shape shape_;
|
||||
|
||||
|
||||
public:
|
||||
Bernoulli(float p, const Shape& shape)
|
||||
: p_(p), shape_(shape) {}
|
||||
|
||||
|
||||
void InitStates(curandState* states) {
|
||||
states_ = states;
|
||||
int blocks = std::min(MAX_BLOCKS, shape_[0]);
|
||||
@ -29,29 +29,57 @@ class Bernoulli {
|
||||
gInitCurandStates<<<blocks, threads>>>(states_, seed++);
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
|
||||
void FreeStates(curandState* states) {
|
||||
cudaFree(states);
|
||||
}
|
||||
|
||||
|
||||
__device__ float operator()(int i, int j) const {
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
float dist = curand_uniform(&states_[tid]);
|
||||
float dist = curand_uniform(&states_[tid]);
|
||||
float zeroOne = dist > p_;
|
||||
return zeroOne / (1 - p_);
|
||||
}
|
||||
|
||||
|
||||
__device__ int rows() const {
|
||||
return shape_[0];
|
||||
}
|
||||
|
||||
|
||||
__device__ int cols() const {
|
||||
return shape_[1];
|
||||
}
|
||||
|
||||
|
||||
Bernoulli& gpu() {
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
template <class T1, class T2>
|
||||
__global__ void gDropout(T1 out, T2 drop) {
|
||||
int rows = out.rows();
|
||||
int cols = out.cols();
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int i = bid + blockIdx.x;
|
||||
if(i < rows) {
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int j = tid + threadIdx.x;
|
||||
if(j < cols)
|
||||
out(i, j) = drop(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T1, class T2>
|
||||
void Dropout(T1 out, T2 drop) {
|
||||
|
||||
int m = out.shape()[0];
|
||||
int n = out.shape()[1];
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, m);
|
||||
int threads = std::min(MAX_THREADS, n);
|
||||
gDropout<<<blocks, threads>>>(out.gpu(), drop.gpu());
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -32,21 +32,23 @@ using namespace marian;
|
||||
using namespace keywords;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
||||
|
||||
|
||||
Tensor a({1000, 1000}, 3);
|
||||
Tensor mask({1000, 1000});
|
||||
Tensor b({1000, 1000});
|
||||
Bernoulli dropout(0.2, b.shape());
|
||||
|
||||
auto f = [] __device__ (float& r,
|
||||
float a,
|
||||
float b) {
|
||||
return r = a * b;
|
||||
};
|
||||
|
||||
Bernoulli dropout(0.2, mask.shape());
|
||||
|
||||
curandState* states = nullptr;
|
||||
dropout.InitStates(states);
|
||||
|
||||
boost::timer::cpu_timer timer;
|
||||
for(int i = 0; i < 1000; ++i)
|
||||
Element(f, b, a, a);
|
||||
|
||||
for(int i = 0; i < 1000; ++i) {
|
||||
Dropout(mask, dropout);
|
||||
Element(_1 = _2 * _3, b, mask, a);
|
||||
}
|
||||
|
||||
std::cerr << timer.format(5, "%ws") << std::endl;
|
||||
dropout.FreeStates(states);
|
||||
return 0;
|
||||
}
|
||||
|
@ -108,7 +108,6 @@ struct ReLUNodeOp : public UnaryNodeOp {
|
||||
|
||||
};
|
||||
|
||||
// Scaling droput
|
||||
struct DropoutNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
DropoutNodeOp(Args ...args)
|
||||
@ -119,28 +118,25 @@ struct DropoutNodeOp : public UnaryNodeOp {
|
||||
if(bernoulli)
|
||||
bernoulli->FreeStates(states_);
|
||||
}
|
||||
|
||||
|
||||
void inference() {
|
||||
Element(_1 = _2, val_, a_->val());
|
||||
}
|
||||
|
||||
|
||||
void forward() {
|
||||
if(!bernoulli) {
|
||||
bernoulli.reset(new Bernoulli(p_, val_.shape()));
|
||||
bernoulli->InitStates(states_);
|
||||
}
|
||||
|
||||
|
||||
if(!mask_)
|
||||
mask_.allocate(val_.shape());
|
||||
|
||||
auto f = [] __device__ (float& mask, float drop) {
|
||||
return mask = drop;
|
||||
};
|
||||
Element(f, mask_, *bernoulli);
|
||||
Dropout(mask_, *bernoulli);
|
||||
Element(_1 = _2 * _3, val_, mask_, a_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * _3, a_->grad(), adj_, mask_);
|
||||
}
|
||||
|
||||
@ -321,4 +317,3 @@ struct NegNodeOp : public UnaryNodeOp {
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
@ -51,7 +51,7 @@ void Element(Functor functor, T out) {
|
||||
|
||||
int m = out.shape()[0];
|
||||
int n = out.shape()[1];
|
||||
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, m);
|
||||
int threads = std::min(MAX_THREADS, n);
|
||||
gElement<<<blocks, threads>>>(functor, out.gpu());
|
||||
@ -82,7 +82,7 @@ void Element(Functor functor,
|
||||
|
||||
int m = out.shape()[0];
|
||||
int n = out.shape()[1];
|
||||
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, m);
|
||||
int threads = std::min(MAX_THREADS, n);
|
||||
gElement<<<blocks, threads>>>(functor, out.gpu(), in.gpu());
|
||||
@ -112,7 +112,7 @@ void Element(Functor functor,
|
||||
|
||||
int m = out.shape()[0];
|
||||
int n = out.shape()[1];
|
||||
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, m);
|
||||
int threads = std::min(MAX_THREADS, n);
|
||||
gElement<<<blocks, threads>>>(functor, out.gpu(),
|
||||
@ -143,7 +143,7 @@ void Element(Functor functor,
|
||||
|
||||
int m = out.shape()[0];
|
||||
int n = out.shape()[1];
|
||||
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, m);
|
||||
int threads = std::min(MAX_THREADS, n);
|
||||
gElement<<<blocks, threads>>>(functor, out.gpu(),
|
||||
|
Loading…
Reference in New Issue
Block a user