using cudnn dropout, about 10x faster than my own bad implementation

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-22 18:35:44 +02:00
parent c138c68e6a
commit 61c236237c
7 changed files with 98 additions and 189 deletions

View File

@ -3,7 +3,6 @@ include_directories(.)
cuda_add_library(marian_lib
cnpy/cnpy.cpp
dropout.cu
exception.cpp
expression_graph.cu
expression_operators.cu
@ -15,11 +14,6 @@ cuda_add_library(marian_lib
target_link_libraries(marian_lib)
cuda_add_executable(
dropout_benchmark
dropout_benchmark.cu
)
cuda_add_executable(
softmax_benchmark
softmax_benchmark.cu
@ -45,14 +39,13 @@ cuda_add_executable(
test_nodes.cu
)
target_link_libraries(dropout_benchmark marian_lib)
target_link_libraries(softmax_benchmark marian_lib)
target_link_libraries(mnist_benchmark marian_lib)
target_link_libraries(validate_mnist_batch marian_lib)
target_link_libraries(validate_encoder_decoder marian_lib)
target_link_libraries(test_nodes marian_lib)
foreach(exec dropout_benchmark mnist_benchmark softmax_benchmark validate_mnist_batch validate_encoder_decoder test_nodes )
foreach(exec mnist_benchmark softmax_benchmark validate_mnist_batch validate_encoder_decoder test_nodes )
target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn curand)
cuda_add_cublas_to_target(${exec})
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")

View File

@ -1,15 +0,0 @@
#include <curand.h>
#include <curand_kernel.h>
#include "dropout.h"
namespace marian {
// Seeds one curandState per launched thread.
// Every thread uses its flat global index both as the slot it writes in
// `states` and as the curand subsequence number (offset 0). There is no bounds
// check, so `states` has to provide at least gridDim.x * blockDim.x entries.
__global__ void gInitCurandStates(curandState* states, unsigned int seed) {
  const int slot = blockIdx.x * blockDim.x + threadIdx.x;
  curandState* mine = states + slot;
  curand_init(seed, slot, 0, mine);
}
// Out-of-class definition of the static seed shared by all Bernoulli objects;
// it is initialised once from the wall-clock time returned by time(0).
unsigned Bernoulli::seed = time(0);
}

View File

@ -1,85 +0,0 @@
#pragma once
#include <curand.h>
#include <curand_kernel.h>
#include "tensor_operators.h"
namespace marian {
__global__ void gInitCurandStates(curandState* states, unsigned int seed);
// Device-callable generator of a scaled Bernoulli (dropout) mask.
//
// operator() draws a uniform sample from the calling thread's curandState and
// returns 1 / (1 - p) if the sample is greater than p, and 0 otherwise.
//
// The curand states are a device buffer owned by this object: InitStates()
// allocates and seeds it, FreeStates() releases it. Copies of the object
// (e.g. the by-value copy passed to a kernel) share the same buffer and must
// not outlive it.
class Bernoulli {
  private:
    float p_;              // threshold compared against the uniform sample
    curandState* states_;  // device buffer, one state per launched thread
    static unsigned seed;  // incremented on every InitStates() call
    Shape shape_;          // shape of the mask: shape_[0] rows, shape_[1] cols

  public:
    // states_ starts out null so that FreeStates() is safe to call even if
    // InitStates() never ran.
    Bernoulli(float p, const Shape& shape)
    : p_(p), states_(nullptr), shape_(shape) {}

    // Allocates one curandState per thread of the launch configuration
    // min(MAX_BLOCKS, rows) x min(MAX_THREADS, cols) and seeds them on the
    // device.
    //
    // The `states` argument is kept for interface compatibility only. It is
    // received by value, so the caller's pointer can never be made to refer
    // to the buffer allocated here; the buffer is tracked in states_ instead.
    void InitStates(curandState* states) {
      (void)states;

      // Re-initialisation must not leak the buffer of a previous call.
      if(states_) {
        cudaFree(states_);
        states_ = nullptr;
      }

      int blocks  = std::min(MAX_BLOCKS, shape_[0]);
      int threads = std::min(MAX_THREADS, shape_[1]);
      int n = blocks * threads;
      cudaMalloc((void**) &states_, n * sizeof(curandState));
      gInitCurandStates<<<blocks, threads>>>(states_, seed++);
      cudaStreamSynchronize(0);
    }

    // Releases the buffer allocated by InitStates().
    //
    // The `states` argument is ignored: callers only ever hold the pointer
    // they passed to InitStates(), which was never updated, so freeing it
    // would leave the real buffer leaked.
    void FreeStates(curandState* states) {
      (void)states;
      cudaFree(states_);  // cudaFree(nullptr) is a no-op
      states_ = nullptr;
    }

    // Returns the mask value for the calling thread. The (i, j) coordinates
    // are not used for the draw; the state is selected by the thread's flat
    // global index, so the launch must not use more threads than InitStates()
    // allocated states for.
    __device__ float operator()(int i, int j) const {
      int tid = threadIdx.x + blockIdx.x * blockDim.x;
      float dist = curand_uniform(&states_[tid]);
      float zeroOne = dist > p_;
      return zeroOne / (1.0f - p_);
    }

    // Number of rows of the mask (shape_[0]).
    __device__ int rows() const {
      return shape_[0];
    }

    // Number of columns of the mask (shape_[1]).
    __device__ int cols() const {
      return shape_[1];
    }

    // Returns the object itself; lets Dropout() treat the generator like its
    // other operand, on which it also calls gpu().
    Bernoulli& gpu() {
      return *this;
    }
};
// Writes drop(i, j) into out(i, j) for every i in [0, out.rows()) and
// j in [0, out.cols()).
// Blocks step over rows in increments of gridDim.x and threads step over
// columns in increments of blockDim.x, so a 1-D launch of any size visits
// each element exactly once.
template <class T1, class T2>
__global__ void gDropout(T1 out, T2 drop) {
  const int numRows = out.rows();
  const int numCols = out.cols();

  for(int row = blockIdx.x; row < numRows; row += gridDim.x) {
    for(int col = threadIdx.x; col < numCols; col += blockDim.x) {
      out(row, col) = drop(row, col);
    }
  }
}
// Host-side launcher for gDropout.
// Uses min(MAX_BLOCKS, rows) blocks of min(MAX_THREADS, cols) threads, where
// rows/cols are out.shape()[0] and out.shape()[1], then blocks until stream 0
// has finished.
template <class T1, class T2>
void Dropout(T1 out, T2 drop) {
  const int rowCount = out.shape()[0];
  const int colCount = out.shape()[1];

  const int gridSize  = std::min(MAX_BLOCKS, rowCount);
  const int blockSize = std::min(MAX_THREADS, colCount);

  gDropout<<<gridSize, blockSize>>>(out.gpu(), drop.gpu());
  cudaStreamSynchronize(0);
}
}

View File

@ -1,54 +0,0 @@
// This file is part of the Marian toolkit.
// Marian is copyright (c) 2016 Marcin Junczys-Dowmunt.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#include <fstream>
#include <boost/timer/timer.hpp>
#include "marian.h"
#include "mnist.h"
#include "vocab.h"
#include "tensor_operators.h"
#include "curand.h"
using namespace marian;
using namespace keywords;
// Dropout micro-benchmark: regenerates a 1000x1000 mask and multiplies it
// element-wise into a constant tensor 1000 times, then prints the elapsed
// wall-clock time to stderr.
int main(int argc, char** argv) {
  Tensor input({1000, 1000}, 3);
  Tensor mask({1000, 1000});
  Tensor output({1000, 1000});

  Bernoulli bernoulli(0.2, mask.shape());

  // InitStates() receives this pointer by value, so it is still nullptr
  // when it is handed to FreeStates() below.
  curandState* rngStates = nullptr;
  bernoulli.InitStates(rngStates);

  // The timer starts after state initialisation so only the loop is measured.
  boost::timer::cpu_timer timer;
  int iter = 0;
  while(iter < 1000) {
    Dropout(mask, bernoulli);
    Element(_1 = _2 * _3, output, mask, input);
    ++iter;
  }
  std::cerr << timer.format(5, "%ws") << std::endl;

  bernoulli.FreeStates(rngStates);
  return 0;
}

View File

@ -1,6 +1,5 @@
#include "node.h"
#include "tensor_operators.h"
#include "dropout.h"
namespace marian {
@ -112,32 +111,33 @@ struct DropoutNodeOp : public UnaryNodeOp {
template <typename ...Args>
DropoutNodeOp(Args ...args)
: UnaryNodeOp(args...),
p_(Get<float>(keywords::value, 0.5)) {}
allocated_(false), p_(Get<float>(keywords::value, 0.5)) {}
~DropoutNodeOp() {
if(bernoulli)
bernoulli->FreeStates(states_);
}
if(allocated_)
CudnnDropoutDestroy(dropDesc_, space_, states_);
}
void inference() {
Element(_1 = _2, val_, a_->val());
}
void forward() {
if(!bernoulli) {
bernoulli.reset(new Bernoulli(p_, val_.shape()));
bernoulli->InitStates(states_);
if(!allocated_) {
CudnnDropoutPrepare(a_->val(), p_,
&dropDesc_,
&space_, &spaceSize_,
&states_, (size_t)this); // seeding with pointer address
allocated_ = true;
}
if(!mask_)
mask_.allocate(val_.shape());
Dropout(mask_, *bernoulli);
Element(_1 = _2 * _3, val_, mask_, a_->val());
CudnnDropoutForward(dropDesc_, space_, spaceSize_,
val_, a_->val());
}
void backward() {
Element(_1 += _2 * _3, a_->grad(), adj_, mask_);
CudnnDropoutBackward(dropDesc_, space_, spaceSize_,
a_->grad(), adj_);
}
virtual std::string graphviz() {
@ -149,13 +149,14 @@ struct DropoutNodeOp : public UnaryNodeOp {
};
private:
bool allocated_;
float p_;
curandState* states_;
std::shared_ptr<Bernoulli> bernoulli;
Tensor mask_;
void* states_;
void* space_;
size_t spaceSize_;
cudnnDropoutDescriptor_t dropDesc_;
};
struct SoftmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
SoftmaxNodeOp(Args ...args)

View File

@ -210,7 +210,7 @@ __global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
if(j < rows) {
extern __shared__ float _share[];
float* _sum = _share + blockDim.x;
float* gradRow = grad + j * cols;
const float* adjRow = adj + j * cols;
const float* valRow = val + j * cols;
@ -263,7 +263,7 @@ __global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
if(j < rows) {
extern __shared__ float _share[];
float* _sum = _share + blockDim.x;
float* gradRow = grad + j * cols;
const float* adjRow = adj + j * cols;
const float* valRow = val + j * cols;
@ -271,7 +271,7 @@ __global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
_sum[threadIdx.x] += adjRow[id];
_sum[threadIdx.x] += adjRow[id];
}
}
__syncthreads();
@ -348,22 +348,22 @@ Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
size_t k = A.shape()[1];
if(transA)
std::swap(m, k);
size_t l = B.shape()[0];
size_t n = B.shape()[1];
if(transB)
std::swap(l, n);
size_t lda = A.shape()[1];
size_t ldb = B.shape()[1];
size_t ldc = B.shape()[1];
if(transB)
ldc = B.shape()[0];
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasSgemm(handle, opB, opA,
n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc);
return C;
@ -394,5 +394,57 @@ Tensor SumRowwise(const Tensor A, Tensor result) {
return temp;
}
// Allocates and initialises everything a cuDNN dropout layer needs.
//
//   in        tensor whose cuDNN descriptor determines the reserve-space size
//   p         dropout probability passed to cudnnSetDropoutDescriptor
//   dropDesc  out: newly created dropout descriptor
//   space     out: device buffer for the cuDNN reserve space
//   spaceSize out: size of `space` in bytes
//   states    out: device buffer for the cuDNN random-generator states
//   seed      seed passed to cudnnSetDropoutDescriptor
//
// Everything created here is released by CudnnDropoutDestroy().
// NOTE(review): cuDNN/CUDA status codes are not checked here; on failure the
// buffer pointers are left null rather than uninitialised.
void CudnnDropoutPrepare(Tensor in, float p,
                         cudnnDropoutDescriptor_t* dropDesc,
                         void** space, size_t* spaceSize,
                         void** states, size_t seed) {
  // Give every output a defined value before the queries/allocations so a
  // failing call cannot leave garbage behind for CudnnDropoutDestroy().
  size_t statesSize = 0;
  *spaceSize = 0;
  *states = nullptr;
  *space = nullptr;

  cudnnDropoutGetStatesSize(cudnnHandle, &statesSize);
  cudnnDropoutGetReserveSpaceSize(in.cudnn(), spaceSize);

  cudaMalloc(states, statesSize);
  cudaMalloc(space, *spaceSize);

  // The descriptor keeps a pointer to the states buffer, so the buffer has to
  // exist before the descriptor is configured.
  cudnnCreateDropoutDescriptor(dropDesc);
  cudnnSetDropoutDescriptor(*dropDesc,
                            cudnnHandle,
                            p,
                            *states,
                            statesSize,
                            seed);
}
// Tears down what CudnnDropoutPrepare() set up: destroys the dropout
// descriptor first, then frees the reserve-space and generator-state buffers.
void CudnnDropoutDestroy(cudnnDropoutDescriptor_t dropDesc,
                         void* space, void* states) {
  cudnnDestroyDropoutDescriptor(dropDesc);

  void* deviceBuffers[] = { space, states };
  for(void* buffer : deviceBuffers)
    cudaFree(buffer);
}
// Applies dropout to `in` and stores the result in `out` via
// cudnnDropoutForward, using the global cudnnHandle.
// `space` / `spaceSize` are the reserve-space buffer and its size in bytes;
// the same buffer has to be handed to CudnnDropoutBackward() afterwards.
void CudnnDropoutForward(cudnnDropoutDescriptor_t dropoutDesc,
                         void* space, size_t spaceSize,
                         Tensor out, Tensor in) {
  cudnnDropoutForward(cudnnHandle, dropoutDesc,
                      in.cudnn(), in.data(),     // source descriptor + data
                      out.cudnn(), out.data(),   // destination descriptor + data
                      space, spaceSize);
}
// Propagates a gradient through dropout via cudnnDropoutBackward, using the
// global cudnnHandle: `in` is the incoming gradient, `out` receives the
// gradient with respect to the dropout input.
// `space` / `spaceSize` must be the reserve space filled by the matching
// CudnnDropoutForward() call.
// NOTE(review): cuDNN appears to write `out` rather than add to it - confirm
// against callers that expect gradient accumulation.
void CudnnDropoutBackward(cudnnDropoutDescriptor_t dropoutDesc,
                          void* space, size_t spaceSize,
                          Tensor out, Tensor in) {
  cudnnDropoutBackward(cudnnHandle, dropoutDesc,
                       in.cudnn(), in.data(),     // incoming gradient
                       out.cudnn(), out.data(),   // resulting gradient
                       space, spaceSize);
}
}

View File

@ -180,4 +180,21 @@ Tensor SumRowwise(const Tensor A, Tensor result);
void ScaleRowwise(Tensor Out, const Tensor ScalingFactors);
// Creates the dropout descriptor and the device buffers (reserve space and
// generator states) for dropout on tensors shaped like `in`; `p` and `seed`
// are passed to cuDNN. Results are returned through the pointer arguments.
void CudnnDropoutPrepare(Tensor in, float p,
cudnnDropoutDescriptor_t* dropDesc,
void** space, size_t* spaceSize,
void** states, size_t seed);
// Destroys the descriptor and frees the two device buffers.
void CudnnDropoutDestroy(cudnnDropoutDescriptor_t dropDesc,
void* space, void* states);
// Runs cuDNN dropout on `in`, writing the result to `out`; `space` is the
// reserve-space buffer of `spaceSize` bytes.
void CudnnDropoutForward(cudnnDropoutDescriptor_t dropoutDesc,
void* space, size_t spaceSize,
Tensor out, Tensor in);
// Runs the cuDNN dropout backward pass with `in` as the incoming gradient and
// `out` as the destination, using the same reserve space as the forward pass.
void CudnnDropoutBackward(cudnnDropoutDescriptor_t dropoutDesc,
void* space, size_t spaceSize,
Tensor out, Tensor in);
}