mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
cudnn softmax, log-softmax
This commit is contained in:
parent
828a0db8bc
commit
e10a95487c
@ -19,6 +19,11 @@ cuda_add_executable(
|
||||
test.cu
|
||||
)
|
||||
|
||||
cuda_add_executable(
|
||||
test_cudnn
|
||||
cudnn.cu
|
||||
)
|
||||
|
||||
cuda_add_executable(
|
||||
mnist_benchmark
|
||||
mnist_benchmark.cu
|
||||
@ -40,12 +45,13 @@ cuda_add_executable(
|
||||
)
|
||||
|
||||
target_link_libraries(marian marian_lib)
|
||||
target_link_libraries(test_cudnn marian_lib)
|
||||
target_link_libraries(mnist_benchmark marian_lib)
|
||||
target_link_libraries(validate_mnist_batch marian_lib)
|
||||
target_link_libraries(validate_encoder_decoder marian_lib)
|
||||
target_link_libraries(test_nodes marian_lib)
|
||||
|
||||
foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder test_nodes)
|
||||
foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder test_nodes test_cudnn)
|
||||
target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn curand)
|
||||
cuda_add_cublas_to_target(${exec})
|
||||
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||
|
149
src/cudnn.cu
Normal file
149
src/cudnn.cu
Normal file
@ -0,0 +1,149 @@
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <time.h>
|
||||
#include <cudnn.h>
|
||||
|
||||
#include <boost/timer/timer.hpp>
|
||||
|
||||
#include "tensor.h"
|
||||
#include "tensor_operators.h"
|
||||
#include "param_initializers.h"
|
||||
|
||||
using namespace marian;
|
||||
|
||||
void CudnnSoftmaxForward(cudnnHandle_t cudnnHandle,
|
||||
Tensor out, Tensor in) {
|
||||
float alpha = 1, beta = 0;
|
||||
cudnnSoftmaxForward(cudnnHandle,
|
||||
CUDNN_SOFTMAX_LOG,
|
||||
CUDNN_SOFTMAX_MODE_CHANNEL,
|
||||
&alpha,
|
||||
in.cudnn(),
|
||||
in.data(),
|
||||
&beta,
|
||||
out.cudnn(),
|
||||
out.data());
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
void CudnnSoftmaxBackward(cudnnHandle_t cudnnHandle,
|
||||
Tensor out, Tensor in) {
|
||||
float alpha = 1, beta = 0;
|
||||
cudnnSoftmaxBackward(cudnnHandle,
|
||||
CUDNN_SOFTMAX_LOG,
|
||||
CUDNN_SOFTMAX_MODE_CHANNEL,
|
||||
&alpha,
|
||||
in.cudnn(),
|
||||
in.data(),
|
||||
out.cudnn(),
|
||||
out.data(),
|
||||
&beta,
|
||||
out.cudnn(),
|
||||
out.data());
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
int main() {
|
||||
cudnnHandle_t cudnnHandle;
|
||||
cudnnCreate(&cudnnHandle);
|
||||
|
||||
int d = 10;
|
||||
|
||||
Tensor in({d, d});
|
||||
Tensor out({d, d});
|
||||
Tensor grad({d, d});
|
||||
Tensor adj({d, d}, 1);
|
||||
|
||||
auto f = uniform(-5, 5);
|
||||
f(in);
|
||||
|
||||
std::cerr << in.Debug() << std::endl;
|
||||
|
||||
{
|
||||
boost::timer::cpu_timer timer;
|
||||
for(int i = 0; i < 1; ++i) {
|
||||
CudnnSoftmaxForward(cudnnHandle, out, in);
|
||||
std::cerr << out.Debug() << std::endl;
|
||||
CudnnSoftmaxBackward(cudnnHandle, grad, in);
|
||||
std::cerr << grad.Debug() << std::endl;
|
||||
}
|
||||
|
||||
std::cerr << timer.format(5, "%ws") << std::endl;
|
||||
}
|
||||
|
||||
{
|
||||
boost::timer::cpu_timer timer;
|
||||
for(int i = 0; i < 1; ++i) {
|
||||
Element(_1 = _2, out, in);
|
||||
Softmax(&out);
|
||||
std::cerr << out.Debug() << std::endl;
|
||||
SoftmaxGrad(grad, adj, out);
|
||||
std::cerr << grad.Debug() << std::endl;
|
||||
}
|
||||
//std::cerr << grad.Debug() << std::endl;
|
||||
std::cerr << timer.format(5, "%ws") << std::endl;
|
||||
}
|
||||
|
||||
|
||||
//// Copy back
|
||||
//float *result = (float *) malloc(m * c * sizeof(float));
|
||||
//cudaMemcpy(result, d_softmaxData, m * c * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
//cudaDeviceSynchronize();
|
||||
//
|
||||
//// Log
|
||||
//printf("SOFTMAX:\n");
|
||||
//printMatrix(result, c, m);
|
||||
//
|
||||
//// Try backward
|
||||
//cudnnTensorDescriptor_t diffTensorDesc;
|
||||
//cudnnCreateTensorDescriptor(&diffTensorDesc);
|
||||
//cudnnSetTensor4dDescriptor(diffTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
|
||||
// m, c, 1, 1);
|
||||
//
|
||||
//float *d_gradData;
|
||||
//cudaMalloc((void**) &d_gradData, m * c * sizeof(float));
|
||||
//
|
||||
//float *diffData = makeDiffData(m, c);
|
||||
//float *d_diffData;
|
||||
//cudaMalloc((void**) &d_diffData, m * c * sizeof(float));
|
||||
//cudaMemcpy(d_diffData, diffData, m * c * sizeof(float), cudaMemcpyHostToDevice);
|
||||
//cudaDeviceSynchronize();
|
||||
//
|
||||
//cudnnSoftmaxBackward(cudnnHandle,
|
||||
// CUDNN_SOFTMAX_ACCURATE,
|
||||
// CUDNN_SOFTMAX_MODE_CHANNEL,
|
||||
// &alpha,
|
||||
// srcTensorDesc,
|
||||
// d_softmaxData,
|
||||
// diffTensorDesc,
|
||||
// d_diffData,
|
||||
// &beta,
|
||||
// sftTensorDesc,
|
||||
// d_gradData);
|
||||
//cudaDeviceSynchronize();
|
||||
//
|
||||
//// Copy back
|
||||
//float *result_backward = (float *) malloc(m * c * sizeof(float));
|
||||
//cudaMemcpy(result_backward, d_gradData, m * c * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
//cudaDeviceSynchronize();
|
||||
//
|
||||
//// Log
|
||||
//printf("GRADIENT:\n");
|
||||
//printMatrix(result_backward, c, m);
|
||||
//
|
||||
//// Destruct
|
||||
//free(result);
|
||||
//free(diffData);
|
||||
//free(result_backward);
|
||||
//free(fcLayer);
|
||||
//
|
||||
//cudnnDestroyTensorDescriptor(srcTensorDesc);
|
||||
//cudnnDestroyTensorDescriptor(sftTensorDesc);
|
||||
//cudnnDestroyTensorDescriptor(diffTensorDesc);
|
||||
//cudaFree(d_fcLayer);
|
||||
//cudaFree(d_softmaxData);
|
||||
//cudaFree(d_gradData);
|
||||
//cudaFree(d_diffData);
|
||||
cudnnDestroy(cudnnHandle);
|
||||
return 0;
|
||||
}
|
@ -122,7 +122,7 @@ int main(int argc, char** argv) {
|
||||
|
||||
boost::timer::cpu_timer total;
|
||||
Adam opt(0.0002);
|
||||
for(int i = 1; i <= 30; ++i) {
|
||||
for(int i = 1; i <= 50; ++i) {
|
||||
boost::timer::cpu_timer timer;
|
||||
shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);
|
||||
float cost = 0;
|
||||
|
@ -249,11 +249,11 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
|
||||
} else {
|
||||
probs_.allocate(a_->val().shape(), 0.0);
|
||||
}
|
||||
thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
|
||||
Softmax(&probs_); // Safe version of softmax.
|
||||
|
||||
CudnnLogSoftmax(probs_, a_->val());
|
||||
if(!result_)
|
||||
result_.allocate(a_->val().shape());
|
||||
Element(_1 = -_2 * Log(_3), result_, b_->val(), probs_);
|
||||
Element(_1 = -_2 * _3, result_, b_->val(), probs_);
|
||||
SumRowwise(result_, val_);
|
||||
}
|
||||
|
||||
@ -262,22 +262,19 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
|
||||
// graph. In general the backward functions can skip the computation of
|
||||
// gradients wrt input nodes.
|
||||
void backward() {
|
||||
// For each row, the first input derivative is given by adj * (p - y),
|
||||
// We are using logsoftmax for this and cached probs are logs.
|
||||
// For each row, the first input derivative is given by adj * (exp(p) - y),
|
||||
// where y is the gold label distribution (e.g. one hot vector) and
|
||||
// p is the softmax output (probabilities).
|
||||
// The second input derivative is -adj*log(p).
|
||||
if(!result_)
|
||||
result_.allocate(probs_.shape());
|
||||
// The second input derivative is -adj*p.
|
||||
|
||||
// Compute first input derivative.
|
||||
Element(_1 = _2 - _3, result_, probs_, b_->val());
|
||||
ScaleRowwise(result_, adj_);
|
||||
Element(_1 += _2, a_->grad(), result_);
|
||||
Element(_1 += _2 * (Exp(_3) - _4),
|
||||
a_->grad(), adj_, probs_, b_->val());
|
||||
|
||||
// Compute second input derivative.
|
||||
Element(_1 = -Log(_2), result_, probs_); // @TODO: use a cached log here.
|
||||
ScaleRowwise(result_, adj_);
|
||||
Element(_1 += _2, b_->grad(), result_);
|
||||
Element(_1 -= _2 * _3, b_->grad(),
|
||||
adj_, probs_);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
|
23
src/tensor.h
23
src/tensor.h
@ -22,6 +22,8 @@
|
||||
// SOFTWARE.
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include <cudnn.h>
|
||||
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/functional.h>
|
||||
#include <numeric>
|
||||
@ -77,6 +79,9 @@ class TensorImpl {
|
||||
thrust::device_vector<Float> data_; /*< Vector of data that Tensor is managing on GPU. */
|
||||
size_t tno_; /*< Tensor number */
|
||||
static size_t tensorCounter; /*< Static counter of created Tensors */
|
||||
|
||||
// cuDNN stuff
|
||||
cudnnTensorDescriptor_t cudnnDesc_;
|
||||
|
||||
public:
|
||||
typedef Float value_type; /*< Tensor value type */
|
||||
@ -100,11 +105,20 @@ class TensorImpl {
|
||||
|
||||
int size = GetTotalSize(shape_);
|
||||
data_.resize(size, value);
|
||||
|
||||
cudnnCreateTensorDescriptor(&cudnnDesc_);
|
||||
cudnnSetTensor4dDescriptorEx(cudnnDesc_, CUDNN_DATA_FLOAT,
|
||||
shape_[0], shape_[1], 1, 1,
|
||||
shape_[1], 1, 1, 1);
|
||||
}
|
||||
|
||||
TensorImpl(const TensorImpl&) = delete;
|
||||
TensorImpl(TensorImpl&&) = delete;
|
||||
|
||||
~TensorImpl() {
|
||||
cudnnDestroyTensorDescriptor(cudnnDesc_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the i-th element of Tensor vector.
|
||||
*
|
||||
@ -249,6 +263,11 @@ class TensorImpl {
|
||||
}
|
||||
return strm.str();
|
||||
}
|
||||
|
||||
cudnnTensorDescriptor_t cudnn() {
|
||||
return cudnnDesc_;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
@ -494,6 +513,10 @@ class Tensor {
|
||||
TensorView gpu() {
|
||||
return TensorView(*this);
|
||||
}
|
||||
|
||||
cudnnTensorDescriptor_t cudnn() {
|
||||
return pimpl_->cudnn();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -29,7 +29,15 @@ static cublasHandle_t create_handle() {
|
||||
cublasCreate(&cublasHandle);
|
||||
return cublasHandle;
|
||||
}
|
||||
|
||||
static cudnnHandle_t create_handle_dnn() {
|
||||
cudnnHandle_t cudnnHandle;
|
||||
cudnnCreate(&cudnnHandle);
|
||||
return cudnnHandle;
|
||||
}
|
||||
|
||||
cublasHandle_t cublasHandle = create_handle();
|
||||
cudnnHandle_t cudnnHandle = create_handle_dnn();
|
||||
|
||||
__global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
|
||||
const int rows, const int cols) {
|
||||
@ -84,6 +92,60 @@ void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
__global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
|
||||
const int rows, const int cols) {
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
extern __shared__ float _share[];
|
||||
float* _sum = _share + blockDim.x;
|
||||
|
||||
float* gradRow = grad + j * cols;
|
||||
const float* adjRow = adj + j * cols;
|
||||
const float* valRow = val + j * cols;
|
||||
_sum[threadIdx.x] = 0.0;
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int id = tid + threadIdx.x;
|
||||
if(id < cols) {
|
||||
_sum[threadIdx.x] += expf(valRow[id]) * adjRow[id]; // exp becaus we chached logsoftmax
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
int len = blockDim.x;
|
||||
while(len != 1) {
|
||||
__syncthreads();
|
||||
int skip = (len + 1) >> 1;
|
||||
if(threadIdx.x < (len >> 1))
|
||||
_sum[threadIdx.x] += _sum[threadIdx.x + skip];
|
||||
len = (len + 1) >> 1;
|
||||
}
|
||||
__syncthreads();
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x){
|
||||
int id = tid + threadIdx.x;
|
||||
if(id < cols)
|
||||
gradRow[id] += adjRow[id] - _sum[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
|
||||
// grad and val are both m-by-k matrices, passed as input.
|
||||
// A weighted average of each row of grad (according to the weights
|
||||
// specified in val) is computed and subtracted from Out.
|
||||
// adj is multiplied for each element to get backward step in autodiff
|
||||
int m = grad.shape()[0];
|
||||
int k = grad.shape()[1];
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, m);
|
||||
int threads = std::min(MAX_THREADS, k);
|
||||
int shared = sizeof(float) * threads * 2;
|
||||
gLogSoftmaxGrad<<<blocks, threads, shared>>>(grad.data(),
|
||||
adj.data(), val.data(),
|
||||
m, k);
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
__global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
@ -346,4 +408,32 @@ void ScaleRowwise(Tensor Out, const Tensor ScalingFactors) {
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
void CudnnSoftmax(Tensor out, Tensor in) {
|
||||
float alpha = 1, beta = 0;
|
||||
cudnnSoftmaxForward(cudnnHandle,
|
||||
CUDNN_SOFTMAX_ACCURATE,
|
||||
CUDNN_SOFTMAX_MODE_CHANNEL,
|
||||
&alpha,
|
||||
in.cudnn(),
|
||||
in.data(),
|
||||
&beta,
|
||||
out.cudnn(),
|
||||
out.data());
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
void CudnnLogSoftmax(Tensor out, Tensor in) {
|
||||
float alpha = 1, beta = 0;
|
||||
cudnnSoftmaxForward(cudnnHandle,
|
||||
CUDNN_SOFTMAX_LOG,
|
||||
CUDNN_SOFTMAX_MODE_CHANNEL,
|
||||
&alpha,
|
||||
in.cudnn(),
|
||||
in.data(),
|
||||
&beta,
|
||||
out.cudnn(),
|
||||
out.data());
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
}
|
@ -158,6 +158,10 @@ void SubtractMax(Tensor* Out);
|
||||
void Softmax(Tensor* Out);
|
||||
|
||||
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
|
||||
void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
|
||||
|
||||
void CudnnSoftmax(Tensor out, Tensor in);
|
||||
void CudnnLogSoftmax(Tensor out, Tensor in);
|
||||
|
||||
void Argmax(Tensor* Out, const Tensor* In);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user