Implemented fast softmax.

Andre Martins 2016-09-14 08:36:01 +01:00
parent 803a562d4b
commit 0380991393
6 changed files with 93 additions and 2 deletions

View File

@@ -171,6 +171,13 @@ inline Expr softmax(Expr a, Args ...args) {
  return e / sum(e, args...);
}

template <typename ...Args>
inline Expr softmax_fast(Expr a, Args ...args) {
  Expr e = Expr(new SoftmaxNodeOp(a, args...));
  return e;
}

// inefficient
template <typename ...Args>
inline Expr mean(Expr a, Args ...args) {
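For reference, softmax_fast builds the same row-wise softmax as the expression-level version above,

  \mathrm{softmax}(a)_{ij} = \frac{\exp(a_{ij})}{\sum_k \exp(a_{ik})},

but as a single SoftmaxNodeOp whose forward pass calls the dedicated gSoftMax kernel further down, instead of composing sum and division nodes in the graph and materialising their intermediate results.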

View File

@@ -101,6 +101,30 @@ struct TanhNodeOp : public UnaryNodeOp {
  }
};

struct SoftmaxNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  SoftmaxNodeOp(ChainPtr a, Args ...args)
    : UnaryNodeOp(a, keywords::shape=newShape(a),
                  args...) { }

  Shape newShape(ChainPtr a) {
    Shape shape = a->shape();
    return shape;
  }

  void forward() {
    // B = softmax(A).
    val_ = a_->val();
    Softmax(&val_);
  }

  void backward() {
    // TODO
    Element(_1 += _2 * Exp(_3),
            a_->grad(), adj_, a_->val());
  }
};

struct LogNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  LogNodeOp(Args ...args)
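The backward() above is a stub, as its TODO notes: Element(_1 += _2 * Exp(_3), a_->grad(), adj_, a_->val()) accumulates adj * exp(a), i.e. the gradient of an element-wise exponential rather than of the normalised softmax. For reference, with y = softmax(a) taken row-wise and \bar{y} the incoming adjoint (adj_), the quantity to accumulate into a_->grad() would be

  \frac{\partial J}{\partial a_{ij}} = y_{ij}\Bigl(\bar{y}_{ij} - \sum_k \bar{y}_{ik}\, y_{ik}\Bigr).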

View File

@@ -240,6 +240,13 @@ class Tensor {
    return pimpl_->Debug();
  }

  void Print() const {
    for (int i = 0; i < size(); ++i) {
      std::cerr << (*this)[i] << " ";
    }
    std::cerr << std::endl;
  }
};

}

View File

@@ -2,6 +2,53 @@

namespace marian {

// TODO: implement this.
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
  for(int bid = 0; bid < rows; bid += gridDim.x) {
    int j = bid + blockIdx.x;
    if(j < rows) {
      extern __shared__ float _share[];
      float* _sum = _share + blockDim.x;
      float* sp = softMaxP + j * cols;
      _sum[threadIdx.x] = 0.0;
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
        if(id < cols) {
          sp[id] = __expf(sp[id]);
          _sum[threadIdx.x] += sp[id];
        }
      }
      __syncthreads();
      int len = blockDim.x;
      while(len != 1) {
        __syncthreads();
        int skip = (len + 1) >> 1;
        if(threadIdx.x < (len >> 1))
          _sum[threadIdx.x] += _sum[threadIdx.x + skip];
        len = (len + 1) >> 1;
      }
      __syncthreads();
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
        if(id < cols)
          sp[id] /= _sum[0];
      }
    }
  }
}
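gSoftMax distributes rows over blocks (j = blockIdx.x, blockIdx.x + gridDim.x, ...) and the columns of each row over the block's threads in strides of blockDim.x. Each thread exponentiates its columns in place and keeps a partial sum in shared memory; the while loop then folds the blockDim.x partial sums pairwise until _sum[0] holds the row total, and a final strided pass divides every element by it. As a plain-CPU reference for checking the kernel's output, here is a minimal sketch (not part of this commit) of the same two-pass computation; like the kernel, it omits the usual max-subtraction stabilisation:

#include <cmath>
#include <cstddef>
#include <vector>

// Reference row-wise softmax on the host (illustration only).
void SoftmaxReference(std::vector<float>& data, size_t rows, size_t cols) {
  for(size_t j = 0; j < rows; ++j) {
    float* sp = data.data() + j * cols;
    float sum = 0.0f;
    for(size_t i = 0; i < cols; ++i) {   // pass 1: exponentiate and accumulate
      sp[i] = std::exp(sp[i]);
      sum += sp[i];
    }
    for(size_t i = 0; i < cols; ++i)     // pass 2: normalise the row
      sp[i] /= sum;
  }
}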
// TODO: implement this.
void Softmax(Tensor* Out) {
  size_t m = Out->shape()[0];
  size_t k = Out->shape()[1];

  int blocks = std::min(MAX_BLOCKS, (int) m);
  int threads = std::min(MAX_THREADS, (int) k);
  int shared = sizeof(float) * threads * 2;
  gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
  cudaStreamSynchronize(0);
}

Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
            bool transA, bool transB, Float beta) {
  Float alpha = 1.0;
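The Softmax wrapper launches one block per row and one thread per column, capped by MAX_BLOCKS and MAX_THREADS, and reserves sizeof(float) * threads * 2 bytes of dynamic shared memory, of which this kernel only touches the upper half (_sum starts at _share + blockDim.x). A small sketch of the resulting configuration for a hypothetical 1000 x 10 tensor; the cap values below are assumed, not taken from this commit:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const int MAX_BLOCKS = 65535;                  // assumed caps, defined elsewhere in the tree
  const int MAX_THREADS = 512;
  size_t m = 1000, k = 10;                       // hypothetical rows x columns
  int blocks  = std::min(MAX_BLOCKS, (int) m);   // one block per row
  int threads = std::min(MAX_THREADS, (int) k);  // one thread per column
  int shared  = sizeof(float) * threads * 2;     // dynamic shared memory in bytes
  std::printf("blocks=%d threads=%d shared=%d\n", blocks, threads, shared);
  return 0;
}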

View File

@@ -142,6 +142,10 @@ void Element(Functor functor,
  cudaStreamSynchronize(0);
}

__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
void Softmax(Tensor* Out);

Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
            bool transA, bool transB, Float beta);

View File

@@ -15,7 +15,7 @@ int main(int argc, char** argv) {
  Expr b = param(shape={1, 10}, name="b0");

  auto scores = dot(x, w) + b;
-  auto lr = softmax(scores, axis=1, name="pred");
+  auto lr = softmax_fast(scores, axis=1, name="pred");
  auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");

  cerr << "lr=" << lr.Debug() << endl;
@@ -40,12 +40,14 @@ int main(int argc, char** argv) {
    std::cerr << val << " ";
  }
  std::cerr << std::endl;
  lr.val().Print();

  std::cerr << "Log-likelihood: ";
  for (auto val : graph.val().shape()) {
    std::cerr << val << " ";
  }
  std::cerr << std::endl;
  graph.val().Print();

  graph.backward();
  //std::cerr << graph["pred"].val()[0] << std::endl;