commit 0380991393
parent 803a562d4b

    Implemented fast softmax.
@@ -171,6 +171,13 @@ inline Expr softmax(Expr a, Args ...args) {
   return e / sum(e, args...);
 }
 
+template <typename ...Args>
+inline Expr softmax_fast(Expr a, Args ...args) {
+  Expr e = Expr(new SoftmaxNodeOp(a, args...));
+  return e;
+}
+
+
 // inefficient
 template <typename ...Args>
 inline Expr mean(Expr a, Args ...args) {
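For reference, both entry points compute the same row-wise function; softmax builds it from explicit exp and sum nodes, while softmax_fast defers to a single fused SoftmaxNodeOp:

\[
\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
\]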
@@ -101,6 +101,30 @@ struct TanhNodeOp : public UnaryNodeOp {
   }
 };
 
+struct SoftmaxNodeOp : public UnaryNodeOp {
+  template <typename ...Args>
+  SoftmaxNodeOp(ChainPtr a, Args ...args)
+  : UnaryNodeOp(a, keywords::shape=newShape(a),
+                args...) { }
+
+  Shape newShape(ChainPtr a) {
+    Shape shape = a->shape();
+    return shape;
+  }
+
+  void forward() {
+    // B = softmax(A).
+    val_ = a_->val();
+    Softmax(&val_);
+  }
+
+  void backward() {
+    // TODO
+    Element(_1 += _2 * Exp(_3),
+            a_->grad(), adj_, a_->val());
+  }
+};
+
 struct LogNodeOp : public UnaryNodeOp {
   template <typename ...Args>
   LogNodeOp(Args ...args)
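Note that backward() is explicitly marked TODO: the Element(...) expression is a placeholder, not the real softmax gradient. For reference, with y = softmax(x) and incoming adjoint adj_, the correct Jacobian-vector product would be

\[
\frac{\partial J}{\partial x_i} = y_i \Big( \mathrm{adj}_i - \sum_j y_j \, \mathrm{adj}_j \Big),
\]

i.e. the diagonal term minus the rank-one correction that comes from the shared normalizer.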
@@ -240,6 +240,13 @@ class Tensor {
     return pimpl_->Debug();
   }
 
+  void Print() const {
+    for (int i = 0; i < size(); ++i) {
+      std::cerr << (*this)[i] << " ";
+    }
+    std::cerr << std::endl;
+  }
+
 };
 
 }
@@ -2,6 +2,53 @@
 
 namespace marian {
 
+// TODO: implement this.
+__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if(j < rows) {
+      extern __shared__ float _share[];
+      float* _sum = _share + blockDim.x;
+      float* sp = softMaxP + j * cols;
+      _sum[threadIdx.x] = 0.0;
+      for(int tid = 0; tid < cols; tid += blockDim.x) {
+        int id = tid + threadIdx.x;
+        if(id < cols) {
+          sp[id] = __expf(sp[id]);
+          _sum[threadIdx.x] += sp[id];
+        }
+      }
+      __syncthreads();
+      int len = blockDim.x;
+      while(len != 1) {
+        __syncthreads();
+        int skip = (len + 1) >> 1;
+        if(threadIdx.x < (len >> 1))
+          _sum[threadIdx.x] += _sum[threadIdx.x + skip];
+        len = (len + 1) >> 1;
+      }
+      __syncthreads();
+      for(int tid = 0; tid < cols; tid += blockDim.x){
+        int id = tid + threadIdx.x;
+        if(id < cols)
+          sp[id] /= _sum[0];
+      }
+    }
+  }
+}
+
+// TODO: implement this.
+void Softmax(Tensor* Out) {
+  size_t m = Out->shape()[0];
+  size_t k = Out->shape()[1];
+
+  int blocks = std::min(MAX_BLOCKS, (int) m);
+  int threads = std::min(MAX_THREADS, (int) k);
+  int shared = sizeof(float) * threads * 2;
+  gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
+  cudaStreamSynchronize(0);
+}
+
 Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta) {
   Float alpha = 1.0;
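One caveat on gSoftMax as committed: it exponentiates the raw logits, so a row with large activations can overflow __expf. The usual remedy is to subtract the row maximum before exponentiating, which is mathematically a no-op for the result. Below is a minimal sketch of that variant, reusing the same launch configuration (the 2 * threads floats of shared memory already allocated by Softmax); gSoftMaxStable is a hypothetical name, not part of this commit:

  #include <cfloat>  // FLT_MAX

  __global__ void gSoftMaxStable(float* softMaxP, size_t rows, size_t cols) {
    extern __shared__ float _share[];
    float* _max = _share;               // first blockDim.x floats
    float* _sum = _share + blockDim.x;  // second blockDim.x floats
    for(int bid = 0; bid < rows; bid += gridDim.x) {
      int j = bid + blockIdx.x;
      if(j < rows) {
        float* sp = softMaxP + j * cols;
        // 1. Block-wide maximum of the row.
        _max[threadIdx.x] = -FLT_MAX;
        for(int tid = 0; tid < cols; tid += blockDim.x) {
          int id = tid + threadIdx.x;
          if(id < cols && sp[id] > _max[threadIdx.x])
            _max[threadIdx.x] = sp[id];
        }
        __syncthreads();
        for(int len = blockDim.x; len != 1; len = (len + 1) >> 1) {
          int skip = (len + 1) >> 1;
          if(threadIdx.x < (len >> 1) && _max[threadIdx.x + skip] > _max[threadIdx.x])
            _max[threadIdx.x] = _max[threadIdx.x + skip];
          __syncthreads();
        }
        // 2. exp(x - max) and row sum, same reduction scheme as gSoftMax.
        _sum[threadIdx.x] = 0.0f;
        for(int tid = 0; tid < cols; tid += blockDim.x) {
          int id = tid + threadIdx.x;
          if(id < cols) {
            sp[id] = __expf(sp[id] - _max[0]);
            _sum[threadIdx.x] += sp[id];
          }
        }
        __syncthreads();
        for(int len = blockDim.x; len != 1; len = (len + 1) >> 1) {
          int skip = (len + 1) >> 1;
          if(threadIdx.x < (len >> 1))
            _sum[threadIdx.x] += _sum[threadIdx.x + skip];
          __syncthreads();
        }
        // 3. Normalize in place; each row then sums to 1.
        for(int tid = 0; tid < cols; tid += blockDim.x) {
          int id = tid + threadIdx.x;
          if(id < cols)
            sp[id] /= _sum[0];
        }
        __syncthreads();  // keep _max/_sum consistent before the next row
      }
    }
  }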
@@ -142,6 +142,10 @@ void Element(Functor functor,
   cudaStreamSynchronize(0);
 }
 
+__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
+
+void Softmax(Tensor* Out);
+
 Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta);
 
@@ -15,7 +15,7 @@ int main(int argc, char** argv) {
   Expr b = param(shape={1, 10}, name="b0");
 
   auto scores = dot(x, w) + b;
-  auto lr = softmax(scores, axis=1, name="pred");
+  auto lr = softmax_fast(scores, axis=1, name="pred");
   auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
   cerr << "lr=" << lr.Debug() << endl;
 
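With lr = softmax_fast(scores), the cost node built on the next line is the batch-averaged cross-entropy: for one-hot labels y and N training examples, -mean(sum(y * log(lr), axis=1), axis=0) evaluates

\[
J = -\frac{1}{N} \sum_{n=1}^{N} \sum_{i} y_i^{(n)} \log \hat{y}_i^{(n)},
\]

where \hat{y}^{(n)} is the softmax output for example n.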
@@ -40,11 +40,13 @@ int main(int argc, char** argv) {
     std::cerr << val << " ";
   }
   std::cerr << std::endl;
+  lr.val().Print();
   std::cerr << "Log-likelihood: ";
   for (auto val : graph.val().shape()) {
     std::cerr << val << " ";
   }
   std::cerr << std::endl;
+  graph.val().Print();
 
   graph.backward();
 