Andre Martins 2016-09-21 16:35:20 +01:00
commit 5b4a50f8b0
11 changed files with 350 additions and 76 deletions

View File

@@ -20,6 +20,11 @@ cuda_add_executable(
test.cu
)
cuda_add_executable(
test_cudnn
cudnn.cu
)
cuda_add_executable(
mnist_benchmark
mnist_benchmark.cu
@@ -41,12 +46,13 @@ cuda_add_executable(
)
target_link_libraries(marian marian_lib)
target_link_libraries(test_cudnn marian_lib)
target_link_libraries(mnist_benchmark marian_lib)
target_link_libraries(validate_mnist_batch marian_lib)
target_link_libraries(validate_encoder_decoder marian_lib)
target_link_libraries(test_nodes marian_lib)
foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder test_nodes)
foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder test_nodes test_cudnn)
target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn curand)
cuda_add_cublas_to_target(${exec})
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")

149 src/cudnn.cu Normal file
View File

@@ -0,0 +1,149 @@
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <cudnn.h>
#include <boost/timer/timer.hpp>
#include "tensor.h"
#include "tensor_operators.h"
#include "param_initializers.h"
using namespace marian;
void CudnnSoftmaxForward(cudnnHandle_t cudnnHandle,
Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxForward(cudnnHandle,
CUDNN_SOFTMAX_LOG,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
void CudnnSoftmaxBackward(cudnnHandle_t cudnnHandle,
Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxBackward(cudnnHandle,
CUDNN_SOFTMAX_LOG,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
out.cudnn(),
out.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
int main() {
cudnnHandle_t cudnnHandle;
cudnnCreate(&cudnnHandle);
int d = 10;
Tensor in({d, d});
Tensor out({d, d});
Tensor grad({d, d});
Tensor adj({d, d}, 1);
auto f = uniform(-5, 5);
f(in);
std::cerr << in.Debug() << std::endl;
{
boost::timer::cpu_timer timer;
for(int i = 0; i < 1; ++i) {
CudnnSoftmaxForward(cudnnHandle, out, in);
std::cerr << out.Debug() << std::endl;
CudnnSoftmaxBackward(cudnnHandle, grad, in);
std::cerr << grad.Debug() << std::endl;
}
std::cerr << timer.format(5, "%ws") << std::endl;
}
{
boost::timer::cpu_timer timer;
for(int i = 0; i < 1; ++i) {
Element(_1 = _2, out, in);
Softmax(&out);
std::cerr << out.Debug() << std::endl;
SoftmaxGrad(grad, adj, out);
std::cerr << grad.Debug() << std::endl;
}
//std::cerr << grad.Debug() << std::endl;
std::cerr << timer.format(5, "%ws") << std::endl;
}
//// Copy back
//float *result = (float *) malloc(m * c * sizeof(float));
//cudaMemcpy(result, d_softmaxData, m * c * sizeof(float), cudaMemcpyDeviceToHost);
//cudaDeviceSynchronize();
//
//// Log
//printf("SOFTMAX:\n");
//printMatrix(result, c, m);
//
//// Try backward
//cudnnTensorDescriptor_t diffTensorDesc;
//cudnnCreateTensorDescriptor(&diffTensorDesc);
//cudnnSetTensor4dDescriptor(diffTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
// m, c, 1, 1);
//
//float *d_gradData;
//cudaMalloc((void**) &d_gradData, m * c * sizeof(float));
//
//float *diffData = makeDiffData(m, c);
//float *d_diffData;
//cudaMalloc((void**) &d_diffData, m * c * sizeof(float));
//cudaMemcpy(d_diffData, diffData, m * c * sizeof(float), cudaMemcpyHostToDevice);
//cudaDeviceSynchronize();
//
//cudnnSoftmaxBackward(cudnnHandle,
// CUDNN_SOFTMAX_ACCURATE,
// CUDNN_SOFTMAX_MODE_CHANNEL,
// &alpha,
// srcTensorDesc,
// d_softmaxData,
// diffTensorDesc,
// d_diffData,
// &beta,
// sftTensorDesc,
// d_gradData);
//cudaDeviceSynchronize();
//
//// Copy back
//float *result_backward = (float *) malloc(m * c * sizeof(float));
//cudaMemcpy(result_backward, d_gradData, m * c * sizeof(float), cudaMemcpyDeviceToHost);
//cudaDeviceSynchronize();
//
//// Log
//printf("GRADIENT:\n");
//printMatrix(result_backward, c, m);
//
//// Destruct
//free(result);
//free(diffData);
//free(result_backward);
//free(fcLayer);
//
//cudnnDestroyTensorDescriptor(srcTensorDesc);
//cudnnDestroyTensorDescriptor(sftTensorDesc);
//cudnnDestroyTensorDescriptor(diffTensorDesc);
//cudaFree(d_fcLayer);
//cudaFree(d_softmaxData);
//cudaFree(d_gradData);
//cudaFree(d_diffData);
cudnnDestroy(cudnnHandle);
return 0;
}
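Note that the two timed blocks above do not compute the same quantity: the cuDNN wrappers request CUDNN_SOFTMAX_LOG (log-softmax), while the second block runs marian's plain Softmax/SoftmaxGrad kernels, so the two branches print log-probabilities and probabilities respectively. For reference, the standard row-wise definitions (not part of this commit) are:

    \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)},
    \qquad
    \log\mathrm{softmax}(x)_i = x_i - \log\sum_j \exp(x_j)

so the forward outputs printed by the two branches should differ by an element-wise log rather than match exactly.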

View File

@@ -122,7 +122,7 @@ int main(int argc, char** argv) {
boost::timer::cpu_timer total;
Adam opt(0.0002);
for(int i = 1; i <= 30; ++i) {
for(int i = 1; i <= 50; ++i) {
boost::timer::cpu_timer timer;
shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);
float cost = 0;

View File

@@ -13,54 +13,67 @@ void Node::calc_numeric_grad(
using namespace std;
size_t inputSize = GetTotalSize(input.shape());
size_t gradSize = GetTotalSize(grad.shape());
size_t adjSize = GetTotalSize(adj_.shape());
size_t valSize = GetTotalSize(val_.shape());
UTIL_THROW_IF2(inputSize != GetTotalSize(grad.shape()),
"inputSize != gradSize:" << inputSize << "!=" << GetTotalSize(grad.shape()));
UTIL_THROW_IF2(valSize != GetTotalSize(adj_.shape()),
"valSize != adjSize :" << valSize << "!=" << GetTotalSize(adj_.shape()));
cerr << "sizes: "
<< Debug(input.shape())<< "=" << inputSize << " "
<< Debug(grad.shape()) << "=" << gradSize << " "
<< Debug(adj_.shape()) << "=" << adjSize
<< Debug(val_.shape()) << "=" << valSize
<< endl;
std::vector<float> diffGrad(gradSize);
thrust::copy(grad.begin(), grad.end(), diffGrad.begin());
cerr << "diffGrad=" << grad.Debug() << endl;
//cerr << "input=" << input.Debug() << endl;
std::vector<float> origGrad(inputSize);
thrust::copy(grad.begin(), grad.end(), origGrad.begin());
cerr << "origGrad=" << grad.Debug() << endl;
//output("diffGrad", diffGrad);
// reset grad
thrust::copy(prevCalcGrad.begin(), prevCalcGrad.end(), grad.begin());
//cerr << "reset a_->grad()=" << a_->grad().Debug() << endl;
//output("prevCalcGrad", prevCalcGrad.begin(), prevCalcGrad.end());
// START CALC of numerical gradient
// new values
input.incr(delta);
std::vector<float> inputVec(inputSize);
thrust::copy(input.begin(), input.end(), inputVec.begin());
//output("inputVec", inputVec);
std::vector<float> newVal(inputSize, 0);
// LOOP thru each element in input & add delta
for (size_t inputInd = 0; inputInd < inputSize; ++inputInd) {
inputVec[inputInd] += delta;
thrust::copy(inputVec.begin(), inputVec.end(), input.begin());
forward();
for (size_t i = 0; i < valSize; ++i) {
newVal[inputInd] += val_[i];
}
inputVec[inputInd] -= delta;
}
// orig value
thrust::copy(inputVec.begin(), inputVec.end(), input.begin());
forward();
//cerr << "input=" << input.Debug() << endl;
//cerr << "val_=" << val_.Debug() << endl;
std::vector<float> newVal(inputSize);
thrust::copy(val_.begin(), val_.end(), newVal.begin());
//output("newVal", newVal);
Float sumValOrig = 0;
for (size_t i = 0; i < valSize; ++i) {
sumValOrig += val_[i];
}
// old values
input.incr(-delta);
forward();
//cerr << "input=" << input.Debug() << endl;
//cerr << "val_=" << val_.Debug() << endl;
std::vector<float> origVal(inputSize);
thrust::copy(val_.begin(), val_.end(), origVal.begin());
//output("origVal", origVal);
//output("newVal", newVal.begin(), newVal.end());
// calc gradient
//cerr << "adj_=" << adj_.Debug() << endl;
std::vector<float> adjVec(adjSize);
std::vector<float> adjVec(valSize);
thrust::copy(adj_.begin(), adj_.end(), adjVec.begin());
std::vector<float> numericalGrad(gradSize);
std::vector<float> numericalGrad(inputSize);
for (size_t i = 0; i < numericalGrad.size(); ++i) {
numericalGrad[i] = prevCalcGrad[i] + (adjVec[i] * (newVal[i] - origVal[i]) / delta);
numericalGrad[i] = (adjVec[i] * (newVal[i] - sumValOrig) / delta);
numericalGrad[i] += prevCalcGrad[i];
}
// set grad results
@@ -68,15 +81,16 @@ void Node::calc_numeric_grad(
cerr << "numericalGrad=" << grad.Debug() << endl;
//output("numericalGrad", numericalGrad);
// print out diff between diffGrad and numericalGrad
std::vector<float> origGrad(gradSize);
std::vector<float> diff(gradSize);
// print out diff between origGrad and numericalGrad
std::vector<float> diff(inputSize);
thrust::copy(grad.begin(), grad.end(), origGrad.begin());
for (size_t i = 0; i < diff.size(); ++i) {
diff[i] = (diffGrad[i] - numericalGrad[i]) ;
diff[i] = (origGrad[i] - numericalGrad[i]) ;
}
output("diff", diff);
output("diff", diff.begin(), diff.end());
// put back origGrad
thrust::copy(origGrad.begin(), origGrad.end(), grad.begin());
}
@@ -88,15 +102,6 @@ std::vector<float> Node::StoreTensorInVec(Tensor tensor)
return vec;
}
void Node::output(const std::string &title, const std::vector<float> &vec)
{
std::cerr << title << "(" << vec.size() << "): ";
for (size_t i = 0; i < vec.size(); ++i) {
std::cerr << vec[i] << " ";
}
std::cerr << std::endl;
}
}
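In equation form, the finite-difference check above perturbs one input element at a time, re-runs forward(), and compares the summed output against the unperturbed sum. A sketch of the estimate the loop computes, using the names from the code (delta is the step size):

    \mathrm{newVal}_i = \sum_k f(x + \delta e_i)_k,
    \qquad
    \mathrm{sumValOrig} = \sum_k f(x)_k

    \mathrm{numericalGrad}_i = \mathrm{prevCalcGrad}_i
        + \mathrm{adj}_i \cdot \frac{\mathrm{newVal}_i - \mathrm{sumValOrig}}{\delta}

The reported diff is then origGrad - numericalGrad, element-wise, after which the analytic gradient is restored.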

View File

@@ -112,7 +112,16 @@ class Node : public Chainable<Tensor>,
Tensor val_;
Tensor adj_;
void output(const std::string &title, const std::vector<float> &vec);
template<class ITER>
void output(const std::string &title, const ITER &b, const ITER &e) const
{
std::cerr << title << ": ";
for (ITER iter = b; iter != e; ++iter) {
std::cerr << *iter << " ";
}
std::cerr << std::endl;
}
std::vector<float> StoreTensorInVec(Tensor tensor);
void calc_numeric_grad(
Float delta,

View File

@@ -26,14 +26,13 @@ struct BinaryNodeOp : public Node {
// use df/dx to calc grad
backward();
//cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
//cerr << "orig b_->grad()=" << b_->grad().Debug() << endl;
cerr << "TENSOR A:" << endl;
calc_numeric_grad(delta, a_->val(), a_->grad(), preCalcGradA);
cerr << "TENSOR B:" << endl;
calc_numeric_grad(delta, b_->val(), b_->grad(), preCalcGradB);
// redo proper grad
backward();
}
@@ -249,11 +248,11 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
} else {
probs_.allocate(a_->val().shape(), 0.0);
}
thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
Softmax(&probs_); // Safe version of softmax.
CudnnLogSoftmax(probs_, a_->val());
if(!result_)
result_.allocate(a_->val().shape());
Element(_1 = -_2 * Log(_3), result_, b_->val(), probs_);
Element(_1 = -_2 * _3, result_, b_->val(), probs_);
SumRowwise(result_, val_);
}
@@ -262,22 +261,19 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
// graph. In general the backward functions can skip the computation of
// gradients wrt input nodes.
void backward() {
// For each row, the first input derivative is given by adj * (p - y),
// We use log-softmax here, so the cached probs are log-probabilities.
// For each row, the first input derivative is given by adj * (exp(p) - y),
// where y is the gold label distribution (e.g. one hot vector) and
// p is the softmax output (probabilities).
// The second input derivative is -adj*log(p).
if(!result_)
result_.allocate(probs_.shape());
// The second input derivative is -adj*p.
// Compute first input derivative.
Element(_1 = _2 - _3, result_, probs_, b_->val());
ScaleRowwise(result_, adj_);
Element(_1 += _2, a_->grad(), result_);
Element(_1 += _2 * (Exp(_3) - _4),
a_->grad(), adj_, probs_, b_->val());
// Compute second input derivative.
Element(_1 = -Log(_2), result_, probs_); // @TODO: use a cached log here.
ScaleRowwise(result_, adj_);
Element(_1 += _2, b_->grad(), result_);
Element(_1 -= _2 * _3, b_->grad(),
adj_, probs_);
}
virtual std::string graphviz() {
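With probs_ now caching log-probabilities p = CudnnLogSoftmax(a), the value and the two input derivatives in this node reduce to the expressions below (a per-row summary of the code above; y is the gold distribution b_->val() and adj the incoming adjoint):

    \mathrm{CE} = -\sum_i y_i\, p_i,
    \qquad
    \frac{\partial \mathrm{CE}}{\partial a_i} = \mathrm{adj}\,(\exp(p_i) - y_i),
    \qquad
    \frac{\partial \mathrm{CE}}{\partial y_i} = -\mathrm{adj}\, p_i

Exponentiating the cached logs recovers the probabilities needed for the first derivative, which is why no separate softmax result has to be stored.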

View File

@@ -22,6 +22,7 @@ struct UnaryNodeOp : public Node {
// use df/dx to calc grad
backward();
//cerr << "orig a_->val()=" << a_->val().Debug() << endl;
//cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
calc_numeric_grad(delta, a_->val(), a_->grad(), preCalcGradA);
@@ -284,6 +285,8 @@ struct NegNodeOp : public UnaryNodeOp {
void backward() {
Element(_1 += -_2, a_->grad(), adj_);
//std::cerr << "a_->grad=" << a_->grad().Debug() << std::endl;
}
virtual std::string graphviz() {

View File

@@ -22,6 +22,8 @@
// SOFTWARE.
#include <cublas_v2.h>
#include <cudnn.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <numeric>
@@ -77,6 +79,9 @@ class TensorImpl {
thrust::device_vector<Float> data_; /*< Vector of data that Tensor is managing on GPU. */
size_t tno_; /*< Tensor number */
static size_t tensorCounter; /*< Static counter of created Tensors */
// cuDNN stuff
cudnnTensorDescriptor_t cudnnDesc_;
public:
typedef Float value_type; /*< Tensor value type */
@@ -100,11 +105,20 @@ class TensorImpl {
int size = GetTotalSize(shape_);
data_.resize(size, value);
cudnnCreateTensorDescriptor(&cudnnDesc_);
cudnnSetTensor4dDescriptorEx(cudnnDesc_, CUDNN_DATA_FLOAT,
shape_[0], shape_[1], 1, 1,
shape_[1], 1, 1, 1);
}
TensorImpl(const TensorImpl&) = delete;
TensorImpl(TensorImpl&&) = delete;
~TensorImpl() {
cudnnDestroyTensorDescriptor(cudnnDesc_);
}
/**
* @brief Get the i-th element of Tensor vector.
*
@@ -208,12 +222,6 @@ class TensorImpl {
thrust::copy(begin, end, data_.begin());
}
void incr(Float incr) {
for (size_t i = 0; i < data_.size(); ++i) {
data_[i] += incr;
}
}
/**
* @brief Copy Tensor's vector from GPU to vector variable on CPU.
*
@@ -249,6 +257,11 @@ class TensorImpl {
}
return strm.str();
}
cudnnTensorDescriptor_t cudnn() {
return cudnnDesc_;
}
};
template <typename Type>
@@ -438,11 +451,6 @@ class Tensor {
*/
void set(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end);
void incr(Float incr) {
pimpl_->incr(incr);
}
/**
* @brief Copy Tensor's vector from GPU to vector variable on CPU (const).
*
@@ -494,6 +502,10 @@ class Tensor {
TensorView gpu() {
return TensorView(*this);
}
cudnnTensorDescriptor_t cudnn() {
return pimpl_->cudnn();
}
};
/**
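The descriptor set up in the TensorImpl constructor views a 2-D Tensor as a 4-D NCHW tensor of shape (rows, cols, 1, 1) with explicit strides, which is what lets CUDNN_SOFTMAX_MODE_CHANNEL act row-wise. A standalone sketch of the same mapping for an m-by-k row-major matrix (the helper name is hypothetical, error checking omitted):

    #include <cudnn.h>

    // Hypothetical standalone version of the mapping done in the TensorImpl constructor.
    cudnnTensorDescriptor_t make_matrix_desc(int m, int k) {  // m rows, k columns, row-major
      cudnnTensorDescriptor_t desc;
      cudnnCreateTensorDescriptor(&desc);
      cudnnSetTensor4dDescriptorEx(desc, CUDNN_DATA_FLOAT,
                                   m, k, 1, 1,    // n = rows, c = cols, h = w = 1
                                   k, 1, 1, 1);   // strides: next row is k floats away, next column is 1
      return desc;                                // caller destroys with cudnnDestroyTensorDescriptor
    }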

View File

@@ -29,7 +29,15 @@ static cublasHandle_t create_handle() {
cublasCreate(&cublasHandle);
return cublasHandle;
}
static cudnnHandle_t create_handle_dnn() {
cudnnHandle_t cudnnHandle;
cudnnCreate(&cudnnHandle);
return cudnnHandle;
}
cublasHandle_t cublasHandle = create_handle();
cudnnHandle_t cudnnHandle = create_handle_dnn();
__global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
const int rows, const int cols) {
@@ -84,6 +92,60 @@ void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
cudaStreamSynchronize(0);
}
__global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
const int rows, const int cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
extern __shared__ float _share[];
float* _sum = _share + blockDim.x;
float* gradRow = grad + j * cols;
const float* adjRow = adj + j * cols;
const float* valRow = val + j * cols;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
_sum[threadIdx.x] += expf(valRow[id]) * adjRow[id]; // exp because we cached log-softmax
}
}
__syncthreads();
int len = blockDim.x;
while(len != 1) {
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1))
_sum[threadIdx.x] += _sum[threadIdx.x + skip];
len = (len + 1) >> 1;
}
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
gradRow[id] += adjRow[id] - _sum[0];
}
}
}
}
void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
// grad and val are both m-by-k matrices, passed as input.
// val holds the cached log-softmax outputs; for each row, the adjoint is
// summed with weights exp(val) and that row sum is subtracted from each
// adjoint element before being accumulated into grad (the autodiff backward step).
int m = grad.shape()[0];
int k = grad.shape()[1];
int blocks = std::min(MAX_BLOCKS, m);
int threads = std::min(MAX_THREADS, k);
int shared = sizeof(float) * threads * 2;
gLogSoftmaxGrad<<<blocks, threads, shared>>>(grad.data(),
adj.data(), val.data(),
m, k);
cudaStreamSynchronize(0);
}
__global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
@@ -346,4 +408,32 @@ void ScaleRowwise(Tensor Out, const Tensor ScalingFactors) {
cudaStreamSynchronize(0);
}
void CudnnSoftmax(Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxForward(cudnnHandle,
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
void CudnnLogSoftmax(Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxForward(cudnnHandle,
CUDNN_SOFTMAX_LOG,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
}
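Per row, the shared-memory reduction in gLogSoftmaxGrad forms a single scalar from the adjoint weighted by exp(val) (val holds the cached log-softmax outputs), and that scalar is then subtracted from each adjoint element. Written out, with indices ranging over one row, the kernel accumulates:

    s = \sum_j \exp(\mathrm{val}_j)\,\mathrm{adj}_j,
    \qquad
    \mathrm{grad}_i \mathrel{+}= \mathrm{adj}_i - s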

View File

@@ -158,6 +158,10 @@ void SubtractMax(Tensor* Out);
void Softmax(Tensor* Out);
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void CudnnSoftmax(Tensor out, Tensor in);
void CudnnLogSoftmax(Tensor out, Tensor in);
void Argmax(Tensor* Out, const Tensor* In);
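A minimal usage sketch of how the two new declarations are meant to pair up, based on their use elsewhere in this commit: CudnnLogSoftmax caches log-probabilities which LogSoftmaxGrad then consumes. The wrapper function and sizes below are hypothetical:

    #include "tensor.h"
    #include "tensor_operators.h"

    void log_softmax_roundtrip() {        // hypothetical helper, not part of the commit
      using namespace marian;
      int m = 4, k = 8;                   // arbitrary example sizes
      Tensor in({m, k});                  // raw scores (fill e.g. with uniform(-5, 5))
      Tensor val({m, k});                 // receives row-wise log-softmax of in
      Tensor adj({m, k}, 1.0);            // incoming adjoint
      Tensor grad({m, k}, 0.0);           // gradient accumulator

      CudnnLogSoftmax(val, in);           // forward: val = log-softmax(in) per row
      LogSoftmaxGrad(grad, adj, val);     // backward: grad += adj - rowsum(exp(val) * adj)
    }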

View File

@@ -55,7 +55,7 @@ int main(int argc, char** argv)
// train
g.forward(batch_size);
//g.backward();
g.backward_debug(0.00001);
g.backward_debug(0.001);
std::cout << g.graphviz() << std::endl;