Moved max-subtraction (safe softmax) inside the Softmax implementation.

Andre Martins 2016-09-17 23:09:35 +01:00
parent 089f76751b
commit 20c80e4bbc
2 changed files with 4 additions and 5 deletions
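
For context, the identity this commit exploits (stated in the @TODO that the first hunk deletes) is that softmax is invariant to subtracting a per-row constant; shifting by the row max changes nothing mathematically but keeps every exponent non-positive:

    softmax(x)_i = exp(x_i) / sum_j exp(x_j)
                 = exp(x_i - m) / sum_j exp(x_j - m),   where m = max_j x_j

Since x_i - m <= 0, the exponentials can no longer overflow; at worst they underflow to 0, which is harmless because the term for the max itself is exp(0) = 1, so the denominator stays at least 1.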

@@ -144,10 +144,6 @@ struct TanhNodeOp : public UnaryNodeOp {
 };
-// @TODO, make this numerically safe(r):
-// softmax(X) = softmax_safe(X - max(X, axis=1))
-// Probably best to do this directly in Softmax
-// function.
 struct SoftmaxNodeOp : public UnaryNodeOp {
   template <typename ...Args>
   SoftmaxNodeOp(Args ...args)
@@ -156,7 +152,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
   void forward() {
     // B = softmax(A).
     val_ = a_->val();
-    SubtractMax(&val_); // Safe version of softmax.
+    // Safe version of softmax.
     Softmax(&val_);
   }
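
The overflow this shift prevents is easy to hit in single precision: expf(89.0f) is already inf, so a naive softmax on large logits divides inf by inf and fills the row with NaN. A minimal CPU reference of the same safe-softmax computation (an illustrative sketch only; safeSoftmax is a hypothetical helper, not code from this repository):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Reference safe softmax for one row: shift by the row max so that
// every argument to expf is <= 0, then normalize.
std::vector<float> safeSoftmax(const std::vector<float>& x) {
  float m = *std::max_element(x.begin(), x.end());
  std::vector<float> y(x.size());
  float sum = 0.f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = expf(x[i] - m);  // exponent <= 0, cannot overflow
    sum += y[i];
  }
  for (float& v : y)
    v /= sum;
  return y;
}

int main() {
  // A naive softmax on these logits computes inf / inf and yields NaN;
  // the shifted version returns the expected probabilities.
  for (float p : safeSoftmax({1000.f, 1001.f, 1002.f}))
    std::printf("%.3f ", p);  // prints: 0.090 0.245 0.665
  std::printf("\n");
}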

@@ -155,6 +155,9 @@ void Softmax(Tensor* Out) {
   int blocks = std::min(MAX_BLOCKS, (int) m);
   int threads = std::min(MAX_THREADS, (int) k);
   int shared = sizeof(float) * threads * 2;
+  // Subtract the max rowwise for numerical stability (safe softmax).
+  gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
+  cudaStreamSynchronize(0);
   gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
   cudaStreamSynchronize(0);
 }
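
gSubtractMax itself is not part of this hunk, so the kernel below is only a sketch of what a compatible implementation could look like given the launch configuration above: each block strides over rows, and the threads of a block cooperate through a shared-memory tree reduction to find each row's max before subtracting it in place. The name gSubtractMaxSketch, the power-of-two block-size assumption, and every other detail here are assumptions for illustration, not the repository's actual kernel.

#include <cfloat>
#include <cstddef>

// Sketch only: row-wise max subtraction matching a launch of the form
// gSubtractMax<<<blocks, threads, shared>>>(data, rows, cols).
// Assumes blockDim.x is a power of two; a production kernel must also
// handle ragged block sizes. Uses blockDim.x floats of shared memory
// (the launch above allocates 2 * threads floats, which is enough).
__global__ void gSubtractMaxSketch(float* data, size_t rows, size_t cols) {
  extern __shared__ float sdata[];
  // Grid-stride over rows: there may be fewer blocks than rows.
  for (size_t row = blockIdx.x; row < rows; row += gridDim.x) {
    float* rowData = data + row * cols;
    // Each thread takes the max over a strided slice of the row.
    float localMax = -FLT_MAX;
    for (size_t i = threadIdx.x; i < cols; i += blockDim.x)
      localMax = fmaxf(localMax, rowData[i]);
    sdata[threadIdx.x] = localMax;
    __syncthreads();
    // Shared-memory tree reduction down to sdata[0] = row max.
    for (unsigned len = blockDim.x / 2; len > 0; len /= 2) {
      if (threadIdx.x < len)
        sdata[threadIdx.x] = fmaxf(sdata[threadIdx.x], sdata[threadIdx.x + len]);
      __syncthreads();
    }
    float rowMax = sdata[0];
    __syncthreads();  // all threads read sdata[0] before it is overwritten
    // Shift the row in place.
    for (size_t i = threadIdx.x; i < cols; i += blockDim.x)
      rowData[i] -= rowMax;
  }
}

As an aside on the launch itself: kernels issued to the same stream already execute in order, so the cudaStreamSynchronize(0) between the two launches is not needed for ordering; it mainly makes kernel errors surface on the host right after the launch that caused them.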