mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Implemented gradient of fast softmax.
This commit is contained in:
parent
76cda34544
commit
0421d8504d
@ -119,9 +119,12 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void backward() {
  // Gradient of softmax. For each row, the Jacobian of softmax applied
  // to the incoming gradient vector dy is:
  //   J * dy = p .* (dy - avg * 1)
  // where avg = p' * dy (a scalar per row) and p is the softmax output
  // (the probabilities stored in val_).
  //
  // NOTE(review): `Tensor result = adj_;` presumably makes an owning
  // copy of the buffer; if Tensor is a shallow handle, SubtractMean
  // below would also mutate adj_ in place — confirm Tensor copy
  // semantics against the Tensor class.
  Tensor result = adj_;
  // result := dy - avg, computed row-wise with weights p = val_.
  SubtractMean(&result, val_);
  // Accumulate the element-wise product p .* (dy - avg) into the
  // gradient. The previous code used Prod(...), a matrix-matrix
  // product of two m-by-k operands, which is not conformable and does
  // not match the element-wise Jacobian action documented above.
  Element(_1 += _2 * _3,
          a_->grad(), val_, result);
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2,7 +2,57 @@
|
|||||||
|
|
||||||
namespace marian {
|
namespace marian {
|
||||||
|
|
||||||
// TODO: implement this.
|
// For each row j of `out`, computes the weighted sum
//   avg = sum_i weights[j*cols + i] * out[j*cols + i]
// and subtracts it from every element of that row: out[j][i] -= avg.
//
// Launch layout: blocks grid-stride over rows; the blockDim.x threads of
// a block cooperate on the columns of one row. Requires dynamic shared
// memory of at least 2 * blockDim.x * sizeof(float); only the upper half
// is used here — the offset matches the host-side allocation in
// SubtractMean (see that function before trimming either side).
__global__ void gSubtractMean(float* out, float* weights,
                              size_t rows, size_t cols) {
  for(int bid = 0; bid < rows; bid += gridDim.x) {
    int j = bid + blockIdx.x;
    // Uniform per block (depends only on blockIdx), so the barriers
    // below are reached by all threads of the block together.
    if(j < rows) {
      extern __shared__ float _share[];
      // Per-thread partial sums live in the second half of the buffer.
      float* _sum = _share + blockDim.x;
      float* sp = out + j * cols;      // row j of out (read and written)
      float* w  = weights + j * cols;  // row j of the weights
      // Strided partial dot product w . sp for this thread.
      _sum[threadIdx.x] = 0.0f;        // was 0.0: avoid a double literal in float code
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
        if(id < cols) {
          _sum[threadIdx.x] += w[id] * sp[id];
        }
      }
      __syncthreads();
      // Tree reduction of the blockDim.x partials into _sum[0].
      // The (len + 1) >> 1 split handles non-power-of-two block sizes;
      // readers (indices >= skip) never overlap writers (indices < len >> 1).
      int len = blockDim.x;
      while(len != 1) {
        __syncthreads();
        int skip = (len + 1) >> 1;
        if(threadIdx.x < (len >> 1))
          _sum[threadIdx.x] += _sum[threadIdx.x + skip];
        len = (len + 1) >> 1;
      }
      __syncthreads();
      // Subtract the row's weighted sum from every element of the row.
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
        if(id < cols)
          sp[id] -= _sum[0];
      }
    }
  }
}
|
||||||
|
|
||||||
|
// Subtracts from each row of *Out a weighted average of that row, with
// the per-element weights taken from the matching row of Weights.
// Both tensors are m-by-k matrices; *Out serves as input and output.
// NOTE(review): gSubtractMean only touches the upper half of the dynamic
// shared buffer, so the `* 2` below over-allocates — the two sides must
// stay in sync; confirm before trimming either one.
void SubtractMean(Tensor* Out, Tensor &Weights) {
  size_t rows = Out->shape()[0];
  size_t cols = Out->shape()[1];

  // Clamp the launch configuration to the device limits; the kernel
  // grid-strides over any rows beyond the block count.
  int nBlocks  = std::min(MAX_BLOCKS,  (int) rows);
  int nThreads = std::min(MAX_THREADS, (int) cols);
  int sharedBytes = sizeof(float) * nThreads * 2;

  gSubtractMean<<<nBlocks, nThreads, sharedBytes>>>(Out->data(),
                                                    Weights.data(),
                                                    rows, cols);
  // Block on the default stream so the in-place update is visible to
  // the caller on return.
  cudaStreamSynchronize(0);
}
|
||||||
|
|
||||||
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
|
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
|
||||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||||
int j = bid + blockIdx.x;
|
int j = bid + blockIdx.x;
|
||||||
@ -37,7 +87,6 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: implement this.
|
|
||||||
void Softmax(Tensor* Out) {
|
void Softmax(Tensor* Out) {
|
||||||
size_t m = Out->shape()[0];
|
size_t m = Out->shape()[0];
|
||||||
size_t k = Out->shape()[1];
|
size_t k = Out->shape()[1];
|
||||||
|
@ -142,6 +142,11 @@ void Element(Functor functor,
|
|||||||
cudaStreamSynchronize(0);
|
cudaStreamSynchronize(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__global__ void gSubtractMean(float* out, float* weights,
|
||||||
|
size_t rows, size_t cols);
|
||||||
|
|
||||||
|
void SubtractMean(Tensor* Out, Tensor &Weights);
|
||||||
|
|
||||||
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
|
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
|
||||||
|
|
||||||
void Softmax(Tensor* Out);
|
void Softmax(Tensor* Out);
|
||||||
|
Loading…
Reference in New Issue
Block a user