diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3f6eb8ff..15e5c19e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -20,6 +20,11 @@ cuda_add_executable( test.cu ) +cuda_add_executable( + test_cudnn + cudnn.cu +) + cuda_add_executable( mnist_benchmark mnist_benchmark.cu @@ -41,12 +46,13 @@ cuda_add_executable( ) target_link_libraries(marian marian_lib) +target_link_libraries(test_cudnn marian_lib) target_link_libraries(mnist_benchmark marian_lib) target_link_libraries(validate_mnist_batch marian_lib) target_link_libraries(validate_encoder_decoder marian_lib) target_link_libraries(test_nodes marian_lib) -foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder test_nodes) +foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder test_nodes test_cudnn) target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn curand) cuda_add_cublas_to_target(${exec}) set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") diff --git a/src/cudnn.cu b/src/cudnn.cu new file mode 100644 index 00000000..49db7372 --- /dev/null +++ b/src/cudnn.cu @@ -0,0 +1,149 @@ +#include +#include +#include +#include + +#include + +#include "tensor.h" +#include "tensor_operators.h" +#include "param_initializers.h" + +using namespace marian; + +void CudnnSoftmaxForward(cudnnHandle_t cudnnHandle, + Tensor out, Tensor in) { + float alpha = 1, beta = 0; + cudnnSoftmaxForward(cudnnHandle, + CUDNN_SOFTMAX_LOG, + CUDNN_SOFTMAX_MODE_CHANNEL, + &alpha, + in.cudnn(), + in.data(), + &beta, + out.cudnn(), + out.data()); + cudaDeviceSynchronize(); +} + +void CudnnSoftmaxBackward(cudnnHandle_t cudnnHandle, + Tensor out, Tensor in) { + float alpha = 1, beta = 0; + cudnnSoftmaxBackward(cudnnHandle, + CUDNN_SOFTMAX_LOG, + CUDNN_SOFTMAX_MODE_CHANNEL, + &alpha, + in.cudnn(), + in.data(), + out.cudnn(), + out.data(), + &beta, + out.cudnn(), + out.data()); + cudaDeviceSynchronize(); +} + +int main() { + cudnnHandle_t cudnnHandle; + cudnnCreate(&cudnnHandle); + + int d = 10; + + Tensor in({d, d}); + Tensor out({d, d}); + Tensor grad({d, d}); + Tensor adj({d, d}, 1); + + auto f = uniform(-5, 5); + f(in); + + std::cerr << in.Debug() << std::endl; + + { + boost::timer::cpu_timer timer; + for(int i = 0; i < 1; ++i) { + CudnnSoftmaxForward(cudnnHandle, out, in); + std::cerr << out.Debug() << std::endl; + CudnnSoftmaxBackward(cudnnHandle, grad, in); + std::cerr << grad.Debug() << std::endl; + } + + std::cerr << timer.format(5, "%ws") << std::endl; + } + + { + boost::timer::cpu_timer timer; + for(int i = 0; i < 1; ++i) { + Element(_1 = _2, out, in); + Softmax(&out); + std::cerr << out.Debug() << std::endl; + SoftmaxGrad(grad, adj, out); + std::cerr << grad.Debug() << std::endl; + } + //std::cerr << grad.Debug() << std::endl; + std::cerr << timer.format(5, "%ws") << std::endl; + } + + + //// Copy back + //float *result = (float *) malloc(m * c * sizeof(float)); + //cudaMemcpy(result, d_softmaxData, m * c * sizeof(float), cudaMemcpyDeviceToHost); + //cudaDeviceSynchronize(); + // + //// Log + //printf("SOFTMAX:\n"); + //printMatrix(result, c, m); + // + //// Try backward + //cudnnTensorDescriptor_t diffTensorDesc; + //cudnnCreateTensorDescriptor(&diffTensorDesc); + //cudnnSetTensor4dDescriptor(diffTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, + // m, c, 1, 1); + // + //float *d_gradData; + //cudaMalloc((void**) &d_gradData, m * c * sizeof(float)); + // + //float *diffData = makeDiffData(m, c); + //float *d_diffData; + //cudaMalloc((void**) &d_diffData, m * c * sizeof(float)); + //cudaMemcpy(d_diffData, diffData, m * c * sizeof(float), cudaMemcpyHostToDevice); + //cudaDeviceSynchronize(); + // + //cudnnSoftmaxBackward(cudnnHandle, + // CUDNN_SOFTMAX_ACCURATE, + // CUDNN_SOFTMAX_MODE_CHANNEL, + // &alpha, + // srcTensorDesc, + // d_softmaxData, + // diffTensorDesc, + // d_diffData, + // &beta, + // sftTensorDesc, + // d_gradData); + //cudaDeviceSynchronize(); + // + //// Copy back + //float *result_backward = (float *) malloc(m * c * sizeof(float)); + //cudaMemcpy(result_backward, d_gradData, m * c * sizeof(float), cudaMemcpyDeviceToHost); + //cudaDeviceSynchronize(); + // + //// Log + //printf("GRADIENT:\n"); + //printMatrix(result_backward, c, m); + // + //// Destruct + //free(result); + //free(diffData); + //free(result_backward); + //free(fcLayer); + // + //cudnnDestroyTensorDescriptor(srcTensorDesc); + //cudnnDestroyTensorDescriptor(sftTensorDesc); + //cudnnDestroyTensorDescriptor(diffTensorDesc); + //cudaFree(d_fcLayer); + //cudaFree(d_softmaxData); + //cudaFree(d_gradData); + //cudaFree(d_diffData); + cudnnDestroy(cudnnHandle); + return 0; +} \ No newline at end of file diff --git a/src/mnist_benchmark.cu b/src/mnist_benchmark.cu index c40d07ad..0dcaf9da 100644 --- a/src/mnist_benchmark.cu +++ b/src/mnist_benchmark.cu @@ -122,7 +122,7 @@ int main(int argc, char** argv) { boost::timer::cpu_timer total; Adam opt(0.0002); - for(int i = 1; i <= 30; ++i) { + for(int i = 1; i <= 50; ++i) { boost::timer::cpu_timer timer; shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE); float cost = 0; diff --git a/src/node.cu b/src/node.cu index f82c68e7..7c9318db 100644 --- a/src/node.cu +++ b/src/node.cu @@ -13,54 +13,67 @@ void Node::calc_numeric_grad( using namespace std; size_t inputSize = GetTotalSize(input.shape()); - size_t gradSize = GetTotalSize(grad.shape()); - size_t adjSize = GetTotalSize(adj_.shape()); + size_t valSize = GetTotalSize(val_.shape()); + + UTIL_THROW_IF2(inputSize != GetTotalSize(grad.shape()), + "inputSize != gradSize:" << inputSize << "!=" << GetTotalSize(grad.shape())); + UTIL_THROW_IF2(valSize != GetTotalSize(adj_.shape()), + "valSize != adjSize :" << valSize << "!=" << GetTotalSize(adj_.shape())); + cerr << "sizes: " << Debug(input.shape())<< "=" << inputSize << " " - << Debug(grad.shape()) << "=" << gradSize << " " - << Debug(adj_.shape()) << "=" << adjSize + << Debug(val_.shape()) << "=" << valSize << endl; - std::vector diffGrad(gradSize); - thrust::copy(grad.begin(), grad.end(), diffGrad.begin()); - cerr << "diffGrad=" << grad.Debug() << endl; + //cerr << "input=" << input.Debug() << endl; + + std::vector origGrad(inputSize); + thrust::copy(grad.begin(), grad.end(), origGrad.begin()); + cerr << "origGrad=" << grad.Debug() << endl; //output("diffGrad", diffGrad); - // reset grad - thrust::copy(prevCalcGrad.begin(), prevCalcGrad.end(), grad.begin()); - //cerr << "reset a_->grad()=" << a_->grad().Debug() << endl; + //output("prevCalcGrad", prevCalcGrad.begin(), prevCalcGrad.end()); - // START CALC of numerical gradient - // new values - input.incr(delta); + std::vector inputVec(inputSize); + thrust::copy(input.begin(), input.end(), inputVec.begin()); + //output("inputVec", inputVec); + std::vector newVal(inputSize, 0); + + // LOOP thru each element in input & add delta + for (size_t inputInd = 0; inputInd < inputSize; ++inputInd) { + inputVec[inputInd] += delta; + thrust::copy(inputVec.begin(), inputVec.end(), input.begin()); + + forward(); + + for (size_t i = 0; i < valSize; ++i) { + newVal[inputInd] += val_[i]; + } + + inputVec[inputInd] -= delta; + } + + // orig value + thrust::copy(inputVec.begin(), inputVec.end(), input.begin()); forward(); - //cerr << "input=" << input.Debug() << endl; - //cerr << "val_=" << val_.Debug() << endl; - std::vector newVal(inputSize); - thrust::copy(val_.begin(), val_.end(), newVal.begin()); - //output("newVal", newVal); + Float sumValOrig = 0; + for (size_t i = 0; i < valSize; ++i) { + sumValOrig += val_[i]; + } - // old values - input.incr(-delta); - - forward(); - //cerr << "input=" << input.Debug() << endl; - //cerr << "val_=" << val_.Debug() << endl; - - std::vector origVal(inputSize); - thrust::copy(val_.begin(), val_.end(), origVal.begin()); - //output("origVal", origVal); + //output("newVal", newVal.begin(), newVal.end()); // calc gradient //cerr << "adj_=" << adj_.Debug() << endl; - std::vector adjVec(adjSize); + std::vector adjVec(valSize); thrust::copy(adj_.begin(), adj_.end(), adjVec.begin()); - std::vector numericalGrad(gradSize); + std::vector numericalGrad(inputSize); for (size_t i = 0; i < numericalGrad.size(); ++i) { - numericalGrad[i] = prevCalcGrad[i] + (adjVec[i] * (newVal[i] - origVal[i]) / delta); + numericalGrad[i] = (adjVec[i] * (newVal[i] - sumValOrig) / delta); + numericalGrad[i] += prevCalcGrad[i]; } // set grad results @@ -68,15 +81,16 @@ void Node::calc_numeric_grad( cerr << "numericalGrad=" << grad.Debug() << endl; //output("numericalGrad", numericalGrad); - // print out diff between diffGrad and numericalGrad - std::vector origGrad(gradSize); - std::vector diff(gradSize); + // print out diff between origGrad and numericalGrad + std::vector diff(inputSize); - thrust::copy(grad.begin(), grad.end(), origGrad.begin()); for (size_t i = 0; i < diff.size(); ++i) { - diff[i] = (diffGrad[i] - numericalGrad[i]) ; + diff[i] = (origGrad[i] - numericalGrad[i]) ; } - output("diff", diff); + output("diff", diff.begin(), diff.end()); + + // put back origGrad + thrust::copy(origGrad.begin(), origGrad.end(), grad.begin()); } @@ -88,15 +102,6 @@ std::vector Node::StoreTensorInVec(Tensor tensor) return vec; } -void Node::output(const std::string &title, const std::vector &vec) -{ - std::cerr << title << "(" << vec.size() << "): "; - for (size_t i = 0; i < vec.size(); ++i) { - std::cerr << vec[i] << " "; - } - std::cerr << std::endl; -} - } diff --git a/src/node.h b/src/node.h index 575f2556..a3264929 100644 --- a/src/node.h +++ b/src/node.h @@ -112,7 +112,16 @@ class Node : public Chainable, Tensor val_; Tensor adj_; - void output(const std::string &title, const std::vector &vec); + template + void output(const std::string &title, const ITER &b, const ITER &e) const + { + std::cerr << title << ": "; + for (ITER iter = b; iter != e; ++iter) { + std::cerr << *iter << " "; + } + std::cerr << std::endl; + } + std::vector StoreTensorInVec(Tensor tensor); void calc_numeric_grad( Float delta, diff --git a/src/node_operators_binary.h b/src/node_operators_binary.h index 9eb7cfd0..c9deeee2 100644 --- a/src/node_operators_binary.h +++ b/src/node_operators_binary.h @@ -26,14 +26,13 @@ struct BinaryNodeOp : public Node { // use df/dx to calc grad backward(); //cerr << "orig a_->grad()=" << a_->grad().Debug() << endl; + //cerr << "orig b_->grad()=" << b_->grad().Debug() << endl; cerr << "TENSOR A:" << endl; calc_numeric_grad(delta, a_->val(), a_->grad(), preCalcGradA); cerr << "TENSOR B:" << endl; calc_numeric_grad(delta, b_->val(), b_->grad(), preCalcGradB); - // redo proper grad - backward(); } @@ -249,11 +248,11 @@ struct CrossEntropyNodeOp : public BinaryNodeOp { } else { probs_.allocate(a_->val().shape(), 0.0); } - thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin()); - Softmax(&probs_); // Safe version of softmax. + + CudnnLogSoftmax(probs_, a_->val()); if(!result_) result_.allocate(a_->val().shape()); - Element(_1 = -_2 * Log(_3), result_, b_->val(), probs_); + Element(_1 = -_2 * _3, result_, b_->val(), probs_); SumRowwise(result_, val_); } @@ -262,22 +261,19 @@ struct CrossEntropyNodeOp : public BinaryNodeOp { // graph. In general the backward functions can skip the computation of // gradients wrt input nodes. void backward() { - // For each row, the first input derivative is given by adj * (p - y), + // We are using logsoftmax for this and cached probs are logs. + // For each row, the first input derivative is given by adj * (exp(p) - y), // where y is the gold label distribution (e.g. one hot vector) and // p is the softmax output (probabilities). - // The second input derivative is -adj*log(p). - if(!result_) - result_.allocate(probs_.shape()); + // The second input derivative is -adj*p. // Compute first input derivative. - Element(_1 = _2 - _3, result_, probs_, b_->val()); - ScaleRowwise(result_, adj_); - Element(_1 += _2, a_->grad(), result_); + Element(_1 += _2 * (Exp(_3) - _4), + a_->grad(), adj_, probs_, b_->val()); // Compute second input derivative. - Element(_1 = -Log(_2), result_, probs_); // @TODO: use a cached log here. - ScaleRowwise(result_, adj_); - Element(_1 += _2, b_->grad(), result_); + Element(_1 -= _2 * _3, b_->grad(), + adj_, probs_); } virtual std::string graphviz() { diff --git a/src/node_operators_unary.h b/src/node_operators_unary.h index c7b13c70..560b4446 100644 --- a/src/node_operators_unary.h +++ b/src/node_operators_unary.h @@ -22,6 +22,7 @@ struct UnaryNodeOp : public Node { // use df/dx to calc grad backward(); + //cerr << "orig a_->val()=" << a_->val().Debug() << endl; //cerr << "orig a_->grad()=" << a_->grad().Debug() << endl; calc_numeric_grad(delta, a_->val(), a_->grad(), preCalcGradA); @@ -284,6 +285,8 @@ struct NegNodeOp : public UnaryNodeOp { void backward() { Element(_1 += -_2, a_->grad(), adj_); + + //std::cerr << "a_->grad=" << a_->grad().Debug() << std::endl; } virtual std::string graphviz() { diff --git a/src/tensor.h b/src/tensor.h index 47b66781..522d705c 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -22,6 +22,8 @@ // SOFTWARE. #include +#include + #include #include #include @@ -77,6 +79,9 @@ class TensorImpl { thrust::device_vector data_; /*< Vector of data that Tensor is managing on GPU. */ size_t tno_; /*< Tensor number */ static size_t tensorCounter; /*< Static counter of created Tensors */ + + // cuDNN stuff + cudnnTensorDescriptor_t cudnnDesc_; public: typedef Float value_type; /*< Tensor value type */ @@ -100,11 +105,20 @@ class TensorImpl { int size = GetTotalSize(shape_); data_.resize(size, value); + + cudnnCreateTensorDescriptor(&cudnnDesc_); + cudnnSetTensor4dDescriptorEx(cudnnDesc_, CUDNN_DATA_FLOAT, + shape_[0], shape_[1], 1, 1, + shape_[1], 1, 1, 1); } TensorImpl(const TensorImpl&) = delete; TensorImpl(TensorImpl&&) = delete; + ~TensorImpl() { + cudnnDestroyTensorDescriptor(cudnnDesc_); + } + /** * @brief Get the i-th element of Tensor vector. * @@ -208,12 +222,6 @@ class TensorImpl { thrust::copy(begin, end, data_.begin()); } - void incr(Float incr) { - for (size_t i = 0; i < data_.size(); ++i) { - data_[i] += incr; - } - } - /** * @brief Copy Tensor's vector from GPU to vector variable on CPU. * @@ -249,6 +257,11 @@ class TensorImpl { } return strm.str(); } + + cudnnTensorDescriptor_t cudnn() { + return cudnnDesc_; + } + }; template @@ -438,11 +451,6 @@ class Tensor { */ void set(const std::vector::const_iterator &begin, const std::vector::const_iterator &end); - void incr(Float incr) { - pimpl_->incr(incr) -; - } - /** * @brief Copy Tensor's vector from GPU to vector variable on CPU (const). * @@ -494,6 +502,10 @@ class Tensor { TensorView gpu() { return TensorView(*this); } + + cudnnTensorDescriptor_t cudnn() { + return pimpl_->cudnn(); + } }; /** diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu index ab07ddac..9f938786 100644 --- a/src/tensor_operators.cu +++ b/src/tensor_operators.cu @@ -29,7 +29,15 @@ static cublasHandle_t create_handle() { cublasCreate(&cublasHandle); return cublasHandle; } + +static cudnnHandle_t create_handle_dnn() { + cudnnHandle_t cudnnHandle; + cudnnCreate(&cudnnHandle); + return cudnnHandle; +} + cublasHandle_t cublasHandle = create_handle(); +cudnnHandle_t cudnnHandle = create_handle_dnn(); __global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val, const int rows, const int cols) { @@ -84,6 +92,60 @@ void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) { cudaStreamSynchronize(0); } +__global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val, + const int rows, const int cols) { + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + extern __shared__ float _share[]; + float* _sum = _share + blockDim.x; + + float* gradRow = grad + j * cols; + const float* adjRow = adj + j * cols; + const float* valRow = val + j * cols; + _sum[threadIdx.x] = 0.0; + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + _sum[threadIdx.x] += expf(valRow[id]) * adjRow[id]; // exp becaus we chached logsoftmax + } + } + __syncthreads(); + int len = blockDim.x; + while(len != 1) { + __syncthreads(); + int skip = (len + 1) >> 1; + if(threadIdx.x < (len >> 1)) + _sum[threadIdx.x] += _sum[threadIdx.x + skip]; + len = (len + 1) >> 1; + } + __syncthreads(); + for(int tid = 0; tid < cols; tid += blockDim.x){ + int id = tid + threadIdx.x; + if(id < cols) + gradRow[id] += adjRow[id] - _sum[0]; + } + } + } +} + +void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) { + // grad and val are both m-by-k matrices, passed as input. + // A weighted average of each row of grad (according to the weights + // specified in val) is computed and subtracted from Out. + // adj is multiplied for each element to get backward step in autodiff + int m = grad.shape()[0]; + int k = grad.shape()[1]; + + int blocks = std::min(MAX_BLOCKS, m); + int threads = std::min(MAX_THREADS, k); + int shared = sizeof(float) * threads * 2; + gLogSoftmaxGrad<<>>(grad.data(), + adj.data(), val.data(), + m, k); + cudaStreamSynchronize(0); +} + __global__ void gSubtractMax(float* out, size_t rows, size_t cols) { for(int bid = 0; bid < rows; bid += gridDim.x) { int j = bid + blockIdx.x; @@ -346,4 +408,32 @@ void ScaleRowwise(Tensor Out, const Tensor ScalingFactors) { cudaStreamSynchronize(0); } +void CudnnSoftmax(Tensor out, Tensor in) { + float alpha = 1, beta = 0; + cudnnSoftmaxForward(cudnnHandle, + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_CHANNEL, + &alpha, + in.cudnn(), + in.data(), + &beta, + out.cudnn(), + out.data()); + cudaDeviceSynchronize(); +} + +void CudnnLogSoftmax(Tensor out, Tensor in) { + float alpha = 1, beta = 0; + cudnnSoftmaxForward(cudnnHandle, + CUDNN_SOFTMAX_LOG, + CUDNN_SOFTMAX_MODE_CHANNEL, + &alpha, + in.cudnn(), + in.data(), + &beta, + out.cudnn(), + out.data()); + cudaDeviceSynchronize(); +} + } \ No newline at end of file diff --git a/src/tensor_operators.h b/src/tensor_operators.h index ab07effe..af0ed59c 100644 --- a/src/tensor_operators.h +++ b/src/tensor_operators.h @@ -158,6 +158,10 @@ void SubtractMax(Tensor* Out); void Softmax(Tensor* Out); void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val); +void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val); + +void CudnnSoftmax(Tensor out, Tensor in); +void CudnnLogSoftmax(Tensor out, Tensor in); void Argmax(Tensor* Out, const Tensor* In); diff --git a/src/test_nodes.cu b/src/test_nodes.cu index 996b1149..5f65edb8 100644 --- a/src/test_nodes.cu +++ b/src/test_nodes.cu @@ -55,7 +55,7 @@ int main(int argc, char** argv) // train g.forward(batch_size); //g.backward(); - g.backward_debug(0.00001); + g.backward_debug(0.001); std::cout << g.graphviz() << std::endl;