From 20c80e4bbce2dcc9a7a632eef3328fa039f24d17 Mon Sep 17 00:00:00 2001 From: Andre Martins Date: Sat, 17 Sep 2016 23:09:35 +0100 Subject: [PATCH] Moved max-subtraction (safe softmax) inside the Softmax implementation. --- src/node_operators.h | 6 +----- src/tensor_operators.cu | 3 +++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/node_operators.h b/src/node_operators.h index db2031e9..507d967d 100644 --- a/src/node_operators.h +++ b/src/node_operators.h @@ -144,10 +144,6 @@ struct TanhNodeOp : public UnaryNodeOp { }; -// @TODO, make this numerically safe(r): -// softmax(X) = softmax_safe(X - max(X, axis=1)) -// Probably best to do this directly in Softmax -// function. struct SoftmaxNodeOp : public UnaryNodeOp { template <typename ...Args> SoftmaxNodeOp(Args ...args) @@ -156,7 +152,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp { void forward() { // B = softmax(A). val_ = a_->val(); - SubtractMax(&val_); // Safe version of softmax. + // Safe version of softmax. Softmax(&val_); } diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu index ad30d051..1d0f7e2f 100644 --- a/src/tensor_operators.cu +++ b/src/tensor_operators.cu @@ -155,6 +155,9 @@ void Softmax(Tensor* Out) { int blocks = std::min(MAX_BLOCKS, (int) m); int threads = std::min(MAX_THREADS, (int) k); int shared = sizeof(float) * threads * 2; + // Subtract the max rowwise for numerical stability (safe softmax). + gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k); + cudaStreamSynchronize(0); gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k); cudaStreamSynchronize(0); }