mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-05 09:40:15 +03:00
Moved max-subtraction (safe softmax) inside the Softmax implementation.
This commit is contained in:
parent
089f76751b
commit
20c80e4bbc
@ -144,10 +144,6 @@ struct TanhNodeOp : public UnaryNodeOp {
|
||||
|
||||
};
|
||||
|
||||
// @TODO, make this numerically safe(r):
|
||||
// softmax(X) = softmax_safe(X - max(X, axis=1))
|
||||
// Probably best to do this directly in Softmax
|
||||
// function.
|
||||
struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
SoftmaxNodeOp(Args ...args)
|
||||
@ -156,7 +152,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
void forward() {
|
||||
// B = softmax(A).
|
||||
val_ = a_->val();
|
||||
SubtractMax(&val_); // Safe version of softmax.
|
||||
// Safe version of softmax.
|
||||
Softmax(&val_);
|
||||
}
|
||||
|
||||
|
@ -155,6 +155,9 @@ void Softmax(Tensor* Out) {
|
||||
int blocks = std::min(MAX_BLOCKS, (int) m);
|
||||
int threads = std::min(MAX_THREADS, (int) k);
|
||||
int shared = sizeof(float) * threads * 2;
|
||||
// Subtract the max rowwise for numerical stability (safe softmax).
|
||||
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
|
||||
cudaStreamSynchronize(0);
|
||||
gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user