Mirror of https://github.com/marian-nmt/marian.git, synced 2024-09-17 09:47:34 +03:00
Implemented safe softmax (but doesn't solve the problem yet, we need log-softmax).
This commit is contained in:
parent eb57df2a3e
commit f6de1677e1
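For context on the commit message: "safe" softmax subtracts each row's maximum before exponentiating so that expf cannot overflow, while log-softmax additionally avoids taking the log of probabilities that have already underflowed to zero, which is the remaining problem the message refers to. A minimal CPU sketch of both variants (illustration only, not part of the commit; the function names are mine):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Safe softmax: subtracting the row max keeps every exponent <= 0,
// so the largest term is exactly 1 and std::exp cannot overflow
// (for float, exp overflows once its argument exceeds ~88.7).
std::vector<float> safe_softmax(const std::vector<float>& x) {
  float mx = *std::max_element(x.begin(), x.end());
  std::vector<float> y(x.size());
  float sum = 0.0f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - mx);  // in (0, 1], no overflow
    sum += y[i];
  }
  for (float& v : y) v /= sum;
  return y;
}

// Log-softmax: log(softmax(x)) computed via log-sum-exp, so no
// intermediate probability is formed and nothing underflows to an
// exact zero before the log is taken.
std::vector<float> log_softmax(const std::vector<float>& x) {
  float mx = *std::max_element(x.begin(), x.end());
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - mx);
  float lse = mx + std::log(sum);  // log of the softmax denominator
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] - lse;
  return y;
}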
@@ -156,6 +156,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
   void forward() {
     // B = softmax(A).
     val_ = a_->val();
+    SubtractMax(&val_); // Safe version of softmax.
     Softmax(&val_);
   }

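A note on why the inserted SubtractMax call leaves the result unchanged (not part of the commit): subtracting a per-row constant c rescales the softmax numerator and denominator by the same factor e^{-c}, so

\operatorname{softmax}(x - c)_i
  = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}
  = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_j e^{x_j}}
  = \operatorname{softmax}(x)_i .

Choosing c as the row maximum makes every exponent at most 0, which rules out overflow in the subsequent Softmax call.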
@@ -95,4 +95,4 @@ class Adam {
     std::vector<Tensor> vt_;
 };

 }
@@ -55,6 +55,56 @@ void SubtractMean(Tensor* Out, Tensor &Weights) {
   cudaStreamSynchronize(0);
 }

+__global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if (j < rows) {
+      extern __shared__ float _share[];
+      float* _max = _share + blockDim.x;
+      float* sp = out + j * cols;
+      _max[threadIdx.x] = sp[threadIdx.x];
+      for(int tid = 1; tid < cols; tid += blockDim.x) {
+        int id = tid + threadIdx.x;
+        if (id < cols) {
+          if (sp[id] > _max[threadIdx.x]) _max[threadIdx.x] = sp[id];
+        }
+      }
+      __syncthreads();
+      int len = blockDim.x;
+      while(len != 1) {
+        __syncthreads();
+        int skip = (len + 1) >> 1;
+        if (threadIdx.x < (len >> 1)) {
+          if (_max[threadIdx.x + skip] > _max[threadIdx.x]) {
+            _max[threadIdx.x] = _max[threadIdx.x + skip];
+          }
+        }
+        len = (len + 1) >> 1;
+      }
+      __syncthreads();
+      for(int tid = 0; tid < cols; tid += blockDim.x){
+        int id = tid + threadIdx.x;
+        if(id < cols)
+          sp[id] -= _max[0];
+      }
+    }
+  }
+}
+
+void SubtractMax(Tensor* Out) {
+  // Out is a m-by-k matrix, passed as input.
+  // The max element of each row of Out is computed and subtracted from Out.
+  // Out is both input and output.
+  size_t m = Out->shape()[0];
+  size_t k = Out->shape()[1];
+
+  int blocks = std::min(MAX_BLOCKS, (int) m);
+  int threads = std::min(MAX_THREADS, (int) k);
+  int shared = sizeof(float) * threads * 2;
+  gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
+  cudaStreamSynchronize(0);
+}
+
 ///////////////////////////////////////////////////////
 __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
   for(int bid = 0; bid < rows; bid += gridDim.x) {
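What gSubtractMax computes, as a host-side reference (illustration only, not from the commit; the function name is mine): each block walks over rows of the matrix, finds the row maximum with a strided pass followed by a shared-memory tree reduction (the (len + 1) >> 1 halving handles odd reduction widths), and subtracts that maximum from every element of the row in place.

#include <algorithm>
#include <cstddef>

// Host-side equivalent of the gSubtractMax kernel: for each row of
// the rows-by-cols matrix, subtract the row maximum in place.
void subtract_max_reference(float* out, std::size_t rows, std::size_t cols) {
  for (std::size_t j = 0; j < rows; ++j) {
    float* sp = out + j * cols;                   // start of row j
    float mx = *std::max_element(sp, sp + cols);  // what the reduction yields
    for (std::size_t i = 0; i < cols; ++i)
      sp[i] -= mx;
  }
}

One design detail worth noting: the kernel requests sizeof(float) * threads * 2 bytes of shared memory and points _max at _share + blockDim.x, i.e. at the second half, even though only threads floats are needed for the max; this appears to follow the same shared-memory layout convention as the neighbouring gSoftMax kernel.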
@@ -147,6 +147,10 @@ __global__ void gSubtractMean(float* out, float* weights,

 void SubtractMean(Tensor* Out, Tensor &Weights);

+__global__ void gSubtractMax(float* out, size_t rows, size_t cols);
+
+void SubtractMax(Tensor* Out);
+
 __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);

 void Softmax(Tensor* Out);