Implemented gradient of fast softmax.

Andre Martins 2016-09-14 14:30:40 +01:00
parent 76cda34544
commit 0421d8504d
3 changed files with 62 additions and 5 deletions


@@ -119,9 +119,12 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
  }

  void backward() {
-    // TODO
-    Element(_1 += _2 * Exp(_3),
-            a_->grad(), adj_, a_->val());
+    // For each row, the Jacobian times vector is given by:
+    // J * dy = p .* (dy - avg*1)
+    // where avg = p'*dy and p is the softmax output (probabilities).
+    Tensor result = adj_;
+    SubtractMean(&result, val_);
+    Prod(a_->grad(), adj_, result, false, false);
  }
};
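For a single row, the softmax Jacobian is diag(p) - p*p', so multiplying it by an incoming gradient dy gives exactly the expression in the comment above: p .* (dy - p'*dy). A minimal CPU sketch of that row-wise product, shown for reference only (the function name and raw-pointer interface are illustrative, not part of this commit):

    // Reference: out[i] = p[i] * (dy[i] - avg), where avg = sum_i p[i] * dy[i].
    void softmax_grad_row(const float* p, const float* dy, float* out, int k) {
      float avg = 0.0f;                 // avg = p' * dy
      for(int i = 0; i < k; ++i)
        avg += p[i] * dy[i];
      for(int i = 0; i < k; ++i)
        out[i] = p[i] * (dy[i] - avg);  // p .* (dy - avg*1)
    }

For example, with p = (0.2, 0.3, 0.5) and dy = (1, 0, 0): avg = 0.2 and the result is (0.16, -0.06, -0.10), which matches (diag(p) - p*p') * dy computed directly; its components sum to zero, as they must for a softmax gradient.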


@@ -2,7 +2,57 @@
namespace marian {
-// TODO: implement this.
+__global__ void gSubtractMean(float* out, float* weights,
+                              size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if(j < rows) {
+      extern __shared__ float _share[];
+      float* _sum = _share + blockDim.x;
+      float* sp = out + j * cols;
+      float* w = weights + j * cols;
+      _sum[threadIdx.x] = 0.0;
+      for(int tid = 0; tid < cols; tid += blockDim.x) {
+        int id = tid + threadIdx.x;
+        if(id < cols) {
+          _sum[threadIdx.x] += w[id] * sp[id];
+        }
+      }
+      __syncthreads();
+      int len = blockDim.x;
+      while(len != 1) {
+        __syncthreads();
+        int skip = (len + 1) >> 1;
+        if(threadIdx.x < (len >> 1))
+          _sum[threadIdx.x] += _sum[threadIdx.x + skip];
+        len = (len + 1) >> 1;
+      }
+      __syncthreads();
+      for(int tid = 0; tid < cols; tid += blockDim.x){
+        int id = tid + threadIdx.x;
+        if(id < cols)
+          sp[id] -= _sum[0];
+      }
+    }
+  }
+}
+
+void SubtractMean(Tensor* Out, Tensor &Weights) {
+  // Out and Weights are both m-by-k matrices, passed as input.
+  // A weighted average of each row of Out (according to the weights
+  // specified in Weights) is computed and subtracted from Out.
+  // Out is both input and output.
+  size_t m = Out->shape()[0];
+  size_t k = Out->shape()[1];
+  int blocks = std::min(MAX_BLOCKS, (int) m);
+  int threads = std::min(MAX_THREADS, (int) k);
+  int shared = sizeof(float) * threads * 2;
+  gSubtractMean<<<blocks, threads, shared>>>(Out->data(), Weights.data(),
+                                             m, k);
+  cudaStreamSynchronize(0);
+}
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
  for(int bid = 0; bid < rows; bid += gridDim.x) {
    int j = bid + blockIdx.x;
@@ -37,7 +87,6 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
  }
}
-// TODO: implement this.
void Softmax(Tensor* Out) {
  size_t m = Out->shape()[0];
  size_t k = Out->shape()[1];
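gSubtractMean assigns one thread block per row (striding by gridDim.x when there are more rows than blocks) and one thread per column (striding by blockDim.x). Each thread accumulates its partial products w[id] * sp[id] in shared memory, the while-loop folds those partials together with a tree reduction, and the final loop subtracts the resulting row sum from every element of the row. A plain CPU equivalent, useful only for checking the kernel on small inputs (the name subtract_mean_ref and the raw-pointer signature are illustrative, not part of this commit):

    // CPU reference for gSubtractMean: from every element of row r, subtract
    // that row's weighted sum  sum_c weights[r*cols + c] * out[r*cols + c].
    void subtract_mean_ref(float* out, const float* weights, int rows, int cols) {
      for(int r = 0; r < rows; ++r) {
        float* o = out + r * cols;
        const float* w = weights + r * cols;
        float sum = 0.0f;
        for(int c = 0; c < cols; ++c)
          sum += w[c] * o[c];
        for(int c = 0; c < cols; ++c)
          o[c] -= sum;
      }
    }

When the weights are the softmax probabilities, as in the backward pass above, each row of Weights sums to one, so this weighted sum is the weighted average referred to in the SubtractMean comment.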


@@ -142,6 +142,11 @@ void Element(Functor functor,
  cudaStreamSynchronize(0);
}

+__global__ void gSubtractMean(float* out, float* weights,
+                              size_t rows, size_t cols);
+void SubtractMean(Tensor* Out, Tensor &Weights);

__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
void Softmax(Tensor* Out);