Merge branch 'master' of github.com:emjotde/Marian

This commit is contained in:
Lane Schwartz 2016-09-21 12:24:25 -05:00
commit 4579f059a0
8 changed files with 216 additions and 327 deletions

View File

@@ -11,47 +11,11 @@
using namespace marian;
void CudnnSoftmaxForward(cudnnHandle_t cudnnHandle,
Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxForward(cudnnHandle,
CUDNN_SOFTMAX_LOG,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
void CudnnSoftmaxBackward(cudnnHandle_t cudnnHandle,
Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxBackward(cudnnHandle,
CUDNN_SOFTMAX_LOG,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
out.cudnn(),
out.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
int main() {
cudnnHandle_t cudnnHandle;
cudnnCreate(&cudnnHandle);
int d = 10;
int d = 4;
Tensor in({d, d});
Tensor out({d, d});
Tensor grad({d, d});
Tensor adj({d, d}, 1);
auto f = uniform(-5, 5);
@@ -62,88 +26,28 @@ int main() {
{
boost::timer::cpu_timer timer;
for(int i = 0; i < 1; ++i) {
CudnnSoftmaxForward(cudnnHandle, out, in);
std::cerr << out.Debug() << std::endl;
CudnnSoftmaxBackward(cudnnHandle, grad, in);
Tensor grad({d, d});
CudnnLogSoftmax(out, in);
CudnnLogSoftmaxGrad(grad, adj, in);
std::cerr << in.Debug() << std::endl;
std::cerr << adj.Debug() << std::endl;
std::cerr << grad.Debug() << std::endl;
}
std::cerr << timer.format(5, "%ws") << std::endl;
}
{
boost::timer::cpu_timer timer;
for(int i = 0; i < 1; ++i) {
Element(_1 = _2, out, in);
Softmax(&out);
std::cerr << out.Debug() << std::endl;
SoftmaxGrad(grad, adj, out);
std::cerr << grad.Debug() << std::endl;
Tensor grad({d, d});
CudnnLogSoftmax(out, in);
LogSoftmaxGrad(grad, adj, in);
std::cerr << in.Debug() << std::endl;
std::cerr << adj.Debug() << std::endl;
std::cerr << grad.Debug() << std::endl;
}
//std::cerr << grad.Debug() << std::endl;
std::cerr << timer.format(5, "%ws") << std::endl;
}
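// In the updated test, both timed blocks run the same cuDNN log-softmax forward;
// the first takes the gradient with CudnnLogSoftmaxGrad, the second with the
// handwritten LogSoftmaxGrad kernel, and each prints in, adj and grad so the two
// paths can be compared by eye.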
//// Copy back
//float *result = (float *) malloc(m * c * sizeof(float));
//cudaMemcpy(result, d_softmaxData, m * c * sizeof(float), cudaMemcpyDeviceToHost);
//cudaDeviceSynchronize();
//
//// Log
//printf("SOFTMAX:\n");
//printMatrix(result, c, m);
//
//// Try backward
//cudnnTensorDescriptor_t diffTensorDesc;
//cudnnCreateTensorDescriptor(&diffTensorDesc);
//cudnnSetTensor4dDescriptor(diffTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
// m, c, 1, 1);
//
//float *d_gradData;
//cudaMalloc((void**) &d_gradData, m * c * sizeof(float));
//
//float *diffData = makeDiffData(m, c);
//float *d_diffData;
//cudaMalloc((void**) &d_diffData, m * c * sizeof(float));
//cudaMemcpy(d_diffData, diffData, m * c * sizeof(float), cudaMemcpyHostToDevice);
//cudaDeviceSynchronize();
//
//cudnnSoftmaxBackward(cudnnHandle,
// CUDNN_SOFTMAX_ACCURATE,
// CUDNN_SOFTMAX_MODE_CHANNEL,
// &alpha,
// srcTensorDesc,
// d_softmaxData,
// diffTensorDesc,
// d_diffData,
// &beta,
// sftTensorDesc,
// d_gradData);
//cudaDeviceSynchronize();
//
//// Copy back
//float *result_backward = (float *) malloc(m * c * sizeof(float));
//cudaMemcpy(result_backward, d_gradData, m * c * sizeof(float), cudaMemcpyDeviceToHost);
//cudaDeviceSynchronize();
//
//// Log
//printf("GRADIENT:\n");
//printMatrix(result_backward, c, m);
//
//// Destruct
//free(result);
//free(diffData);
//free(result_backward);
//free(fcLayer);
//
//cudnnDestroyTensorDescriptor(srcTensorDesc);
//cudnnDestroyTensorDescriptor(sftTensorDesc);
//cudnnDestroyTensorDescriptor(diffTensorDesc);
//cudaFree(d_fcLayer);
//cudaFree(d_softmaxData);
//cudaFree(d_gradData);
//cudaFree(d_diffData);
cudnnDestroy(cudnnHandle);
return 0;
}

View File

@@ -58,6 +58,10 @@ Expr softmax(Expr a) {
return Expr(a.graph(), new SoftmaxNodeOp(a));
}
Expr logsoftmax(Expr a) {
return Expr(a.graph(), new LogSoftmaxNodeOp(a));
}
Expr argmax(Expr a) {
return Expr(a.graph(), new ArgmaxNodeOp(a));
}
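// Usage sketch for the new node, taken verbatim from the MNIST example changed
// later in this commit:
//   auto cost = mean(-sum(y * logsoftmax(scores), axis=1), axis=0);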

View File

@@ -106,6 +106,8 @@ Expr softmax_slow(Expr a, Args ...args) {
Expr softmax(Expr a);
Expr logsoftmax(Expr a);
Expr argmax(Expr a);
// inefficient

View File

@@ -45,8 +45,8 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
auto scores = named(dot(layers.back(), weights.back()) + biases.back(),
"scores");
auto cost = mean(cross_entropy(scores, y), axis=0);
//auto cost = mean(-sum(y * log(softmax(scores)), axis=1), axis=0);
//auto cost = mean(cross_entropy(scores, y), axis=0);
auto cost = mean(-sum(y * logsoftmax(scores), axis=1), axis=0);
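// For a normalized y this is the same cost as cross_entropy(scores, y) above,
// since CrossEntropyNodeOp defines C = -dot(B, logsoftmax(A)); the explicit
// form goes through the new logsoftmax node instead.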
auto costreg = named(
cost, "cost"
);
@@ -115,7 +115,7 @@ int main(int argc, char** argv) {
std::cerr << "Done." << std::endl;
ExpressionGraph g = build_graph({IMAGE_SIZE, 2048, 2048, LABEL_SIZE});
std::cout << g.graphviz() << std::endl;
//std::cout << g.graphviz() << std::endl;
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
Tensor yt({BATCH_SIZE, LABEL_SIZE});

View File

@@ -239,10 +239,10 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
return shape1;
}
// We're caching the softmax probabilities here because we'll need them for
// We're caching the logsoftmax probabilities here because we'll need them for
// the backward computation.
void forward() {
// C = -dot(B, log(softmax(A))).
// C = -dot(B, logsoftmax(A)).
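// For a normalized target B (each row summing to one), differentiating this
// expression gives the well-known identity dC/dA = softmax(A) - B, which is why
// the cached probabilities can be reused directly in the backward pass.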
if (probs_) {
probs_.set(0.0);
} else {

View File

@@ -22,7 +22,6 @@ struct UnaryNodeOp : public Node {
// use df/dx to calc grad
backward();
//cerr << "orig a_->val()=" << a_->val().Debug() << endl;
//cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
calc_numeric_grad(delta, a_->val(), a_->grad(), preCalcGradA);
@@ -167,10 +166,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
: UnaryNodeOp(args...) { }
void forward() {
// B = softmax(A).
thrust::copy(a_->val().begin(), a_->val().end(), val_.begin());
// Safe version of softmax.
Softmax(&val_);
CudnnSoftmax(val_, a_->val());
}
void backward() {
@@ -196,6 +192,33 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
};
};
struct LogSoftmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
LogSoftmaxNodeOp(Args ...args)
: UnaryNodeOp(args...) { }
void forward() {
CudnnLogSoftmax(val_, a_->val());
}
void backward() {
// Following the derivation for softmax, for log-softmax we have
//   J * dy = dy - avg * 1,
// where avg = exp(p)' * dy and p is the cached log-softmax output
// (so exp(p) recovers the probabilities).
CudnnLogSoftmaxGrad(a_->grad(), adj_, val_);
//LogSoftmaxGrad(a_->grad(), adj_, val_);
}
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("log-softmax")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
};
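For comparison, a minimal CPU log-softmax over a single row; this is only a reference sketch (not part of this commit, and the helper name LogSoftmaxRefRow is illustrative), using the same max-subtraction trick as the safe-softmax kernels elsewhere in this change:
#include <algorithm>
#include <cmath>
#include <vector>
// Reference: y_i = x_i - logsumexp(x), computed stably via the rowwise max.
void LogSoftmaxRefRow(std::vector<float>& row) {
  if (row.empty()) return;
  float m = row[0];
  for (float v : row) m = std::max(m, v);       // rowwise max
  float sum = 0.f;
  for (float v : row) sum += std::exp(v - m);   // exponents are <= 0, no overflow
  float logZ = m + std::log(sum);               // log of the partition sum
  for (float& v : row) v -= logZ;               // y_i = x_i - logsumexp(x)
}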
struct ArgmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
ArgmaxNodeOp(ChainPtr a, Args ...args)
@@ -285,8 +308,6 @@ struct NegNodeOp : public UnaryNodeOp {
void backward() {
Element(_1 += -_2, a_->grad(), adj_);
//std::cerr << "a_->grad=" << a_->grad().Debug() << std::endl;
}
virtual std::string graphviz() {

View File

@@ -39,6 +39,166 @@ static cudnnHandle_t create_handle_dnn() {
cublasHandle_t cublasHandle = create_handle();
cudnnHandle_t cudnnHandle = create_handle_dnn();
void CudnnSoftmax(Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxForward(cudnnHandle,
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
void CudnnLogSoftmax(Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxForward(cudnnHandle,
CUDNN_SOFTMAX_LOG,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
void CudnnSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
float alpha = 1, beta = 0;
cudnnSoftmaxBackward(cudnnHandle,
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
val.cudnn(),
val.data(),
adj.cudnn(),
adj.data(),
&beta,
grad.cudnn(),
grad.data());
cudaDeviceSynchronize();
}
void CudnnLogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
float alpha = 1, beta = 0;
cudnnSoftmaxBackward(cudnnHandle,
CUDNN_SOFTMAX_LOG,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
val.cudnn(),
val.data(),
adj.cudnn(),
adj.data(),
&beta,
grad.cudnn(),
grad.data());
cudaDeviceSynchronize();
}
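// Note on the four wrappers above: alpha/beta are cuDNN's usual blend factors
// (dest = alpha * result + beta * dest), so alpha = 1, beta = 0 simply overwrites
// the destination. CUDNN_SOFTMAX_ACCURATE subtracts the rowwise max before
// exponentiating and CUDNN_SOFTMAX_LOG returns log-probabilities directly;
// CUDNN_SOFTMAX_MODE_CHANNEL reduces over the C dimension, which amounts to a
// rowwise softmax assuming Tensor::cudnn() describes the matrix as N x C x 1 x 1.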
__global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if (j < rows) {
extern __shared__ float _share[];
float* _max = _share + blockDim.x;
float* sp = out + j * cols;
_max[threadIdx.x] = sp[threadIdx.x];
for(int tid = 1; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if (id < cols) {
if (sp[id] > _max[threadIdx.x]) _max[threadIdx.x] = sp[id];
}
}
__syncthreads();
int len = blockDim.x;
while(len != 1) {
__syncthreads();
int skip = (len + 1) >> 1;
if (threadIdx.x < (len >> 1)) {
if (_max[threadIdx.x + skip] > _max[threadIdx.x]) {
_max[threadIdx.x] = _max[threadIdx.x + skip];
}
}
len = (len + 1) >> 1;
}
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] -= _max[0];
}
}
}
}
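// The while-loop above is a standard shared-memory tree reduction: each pass
// folds the upper half of the partial maxima into the lower half until _max[0]
// holds the row maximum. _max is offset by blockDim.x floats into the shared
// buffer, which is why SubtractMax below requests 2 * threads * sizeof(float).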
void SubtractMax(Tensor* Out) {
// Out is an m-by-k matrix, passed as input.
// The max element of each row of Out is computed and subtracted from Out.
// Out is both input and output.
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
}
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
extern __shared__ float _share[];
float* _sum = _share + blockDim.x;
float* sp = softMaxP + j * cols;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
sp[id] = __expf(sp[id]);
_sum[threadIdx.x] += sp[id];
}
}
__syncthreads();
int len = blockDim.x;
while(len != 1) {
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1))
_sum[threadIdx.x] += _sum[threadIdx.x + skip];
len = (len + 1) >> 1;
}
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] /= _sum[0];
}
}
}
}
void Softmax(Tensor* Out) {
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
// Subtract the max rowwise for numerical stability (safe softmax).
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
}
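// Softmax is invariant under subtracting a per-row constant:
//   exp(x_i - m) / sum_j exp(x_j - m) = exp(x_i) / sum_j exp(x_j),
// so running gSubtractMax first changes nothing mathematically but keeps every
// exponent non-positive, avoiding overflow in __expf.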
///////////////////////////////////////////////////////
__global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
const int rows, const int cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
@@ -107,7 +267,7 @@ __global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
_sum[threadIdx.x] += expf(valRow[id]) * adjRow[id]; // exp becaus we chached logsoftmax
_sum[threadIdx.x] += expf(valRow[id]) * adjRow[id]; // exp because we cached logsoftmax
}
}
__syncthreads();
@@ -146,158 +306,6 @@ void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
cudaStreamSynchronize(0);
}
__global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if (j < rows) {
extern __shared__ float _share[];
float* _max = _share + blockDim.x;
float* sp = out + j * cols;
_max[threadIdx.x] = sp[threadIdx.x];
for(int tid = 1; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if (id < cols) {
if (sp[id] > _max[threadIdx.x]) _max[threadIdx.x] = sp[id];
}
}
__syncthreads();
int len = blockDim.x;
while(len != 1) {
__syncthreads();
int skip = (len + 1) >> 1;
if (threadIdx.x < (len >> 1)) {
if (_max[threadIdx.x + skip] > _max[threadIdx.x]) {
_max[threadIdx.x] = _max[threadIdx.x + skip];
}
}
len = (len + 1) >> 1;
}
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] -= _max[0];
}
}
}
}
void SubtractMax(Tensor* Out) {
// Out is a m-by-k matrix, passed as input.
// The max element of each row of Out is computed and subtracted from Out.
// Out is both input and output.
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
}
///////////////////////////////////////////////////////
//template <class T>
//__global__ void gClipNorm(T t) {
// int rows = t.rows();
// int cols = t.cols();
//
// for(int bid = 0; bid < rows; bid += gridDim.x) {
// int i = bid + blockIdx.x;
// if(i < rows) {
// extern __shared__ float _share[];
// float* _sum = _share + blockDim.x;
// _sum[threadIdx.x] = 0.0;
// for(int tid = 0; tid < cols; tid += blockDim.x) {
// int j = tid + threadIdx.x;
// if(j < cols)
// _sum[threadIdx.x] += powf(t(i,j), 2.0f);
// }
// __syncthreads();
// int len = blockDim.x;
// while(len != 1) {
// __syncthreads();
// int skip = (len + 1) >> 1;
// if(threadIdx.x < (len >> 1))
// _sum[threadIdx.x] += _sum[threadIdx.x + skip];
// len = (len + 1) >> 1;
// }
// __syncthreads();
// float total = 0;
// if(j == 0) {
// for()
// }
// for(int tid = 0; tid < cols; tid += blockDim.x){
// int j = tid + threadIdx.x;
// if(j < cols)
// sp[j] /= _sum[0];
// }
// }
// }
//}
//
//void ClipNorm(Tensor out, float threshold);
// size_t m = out.shape()[0];
// size_t k = out.shape()[1];
//
// int blocks = std::min(MAX_BLOCKS, (int) m);
// int threads = std::min(MAX_THREADS, (int) k);
// int shared = sizeof(float) * threads * 2;
// gClipNorm<<<blocks, threads, shared>>>(out.gpu());
// cudaStreamSynchronize(0);
//}
///////////////////////////////////////////////////////
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
extern __shared__ float _share[];
float* _sum = _share + blockDim.x;
float* sp = softMaxP + j * cols;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
sp[id] = __expf(sp[id]);
_sum[threadIdx.x] += sp[id];
}
}
__syncthreads();
int len = blockDim.x;
while(len != 1) {
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1))
_sum[threadIdx.x] += _sum[threadIdx.x + skip];
len = (len + 1) >> 1;
}
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] /= _sum[0];
}
}
}
}
void Softmax(Tensor* Out) {
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
// Subtract the max rowwise for numerical stability (safe softmax).
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
}
///////////////////////////////////////////////////////
__global__ void gArgmax(float *out, const float *data, size_t rows, size_t cols) {
size_t row = blockIdx.x;
@@ -382,58 +390,5 @@ Tensor SumRowwise(const Tensor A, Tensor result) {
return temp;
}
// @TODO: replace this by something else when broadcast elementwise operations
// are ready.
__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
Float* rowOut = out + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols) rowOut[i] *= scalingFactors[j];
}
}
}
}
void ScaleRowwise(Tensor Out, const Tensor ScalingFactors) {
Float* d_out = Out.data();
const Float* d_in = ScalingFactors.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gScaleRowwise<<<blocks, threads>>>(d_out, d_in,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
void CudnnSoftmax(Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxForward(cudnnHandle,
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
void CudnnLogSoftmax(Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxForward(cudnnHandle,
CUDNN_SOFTMAX_LOG,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
}

View File

@ -161,7 +161,10 @@ void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void CudnnSoftmax(Tensor out, Tensor in);
void CudnnSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void CudnnLogSoftmax(Tensor out, Tensor in);
void CudnnLogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void Argmax(Tensor* Out, const Tensor* In);
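A minimal call sequence for the cuDNN-backed log-softmax pair declared above; this is a sketch only, borrowing the Tensor shapes and constant-fill constructor from the softmax test at the top of this commit and passing the cached forward output to the backward call, as LogSoftmaxNodeOp::backward does:
Tensor in({4, 4});
Tensor out({4, 4});
Tensor adj({4, 4}, 1);                 // adjoint of the output, here all ones
Tensor grad({4, 4});
CudnnLogSoftmax(out, in);              // out = rowwise log-softmax of in
CudnnLogSoftmaxGrad(grad, adj, out);   // gradient w.r.t. in from adj and the cached out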