diff --git a/src/graph_operators.h b/src/graph_operators.h
index 5a12f807..972456db 100644
--- a/src/graph_operators.h
+++ b/src/graph_operators.h
@@ -14,7 +14,7 @@ struct InputNode : public Node {
                    !Has(keywords::lazy_shape),
                    "Data items require shape information");
   }
-  
+
   virtual void setVal(Tensor t) {
     val_ = t;
     shape_ = t.shape();
@@ -33,7 +33,7 @@ struct ConstantNode : public Node {
                    !Has(keywords::lazy_shape),
                    "Constant items require shape information");
   }
-  
+
   void forward() {}
   void backward() {}
 };
@@ -47,23 +47,23 @@ struct ParamNode : public Node {
     UTIL_THROW_IF2(!Has(keywords::shape) &&
                    !Has(keywords::lazy_shape),
                    "Param items require shape information");
-  }  
-  
+  }
+
   void forward() {}
   void backward() {}
-  
+
   virtual void allocate(size_t batchSize) {
     val_.allocate(shape_);
     init_(val_);
   }
-  
+
  private:
   std::function<void(Tensor)> init_;
 };
 
 struct UnaryNodeOp : public Node {
   ChainPtr a_;
-  
+
   template <typename ...Args>
   UnaryNodeOp(ChainPtr a, Args ...args)
   : Node(args...), a_(a) {}
@@ -73,15 +73,15 @@ struct SigmoidNodeOp : public UnaryNodeOp {
   template <typename ...Args>
   SigmoidNodeOp(Args ...args)
   : UnaryNodeOp(args...) { }
-  
+
   void forward() {
     Element(_1 = Sigma(_2),
             val_, a_->val());
   }
-  
+
   void backward() {
-    Element(_1 += _2 * Sigma(_3) * (1 - Sigma(_3)),
-            a_->grad(), adj_, a_->val());
+    Element(_1 += _2 * _3 * (1 - _3),
+            a_->grad(), adj_, val_);
   }
 };
 
@@ -89,15 +89,15 @@ struct TanhNodeOp : public UnaryNodeOp {
   template <typename ...Args>
   TanhNodeOp(Args ...args)
   : UnaryNodeOp(args...) { }
-  
+
   void forward() {
     Element(_1 = Tanh(_2),
             val_, a_->val());
   }
-  
+
   void backward() {
-    Element(_1 += _2 * (1 - Tanh(_3) * Tanh(_3)),
-            a_->grad(), adj_, a_->val());
+    Element(_1 += _2 * (1 - _3 * _3),
+            a_->grad(), adj_, val_);
   }
 };
 
@@ -106,7 +106,6 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
   SoftmaxNodeOp(ChainPtr a, Args ...args)
   : UnaryNodeOp(a, keywords::shape=newShape(a), args...) { }
-  
   Shape newShape(ChainPtr a) {
     Shape shape = a->shape();
     return shape;
   }
@@ -117,11 +116,16 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
     val_ = a_->val();
     Softmax(&val_);
   }
-  
+
   void backward() {
-    // TODO
-    Element(_1 += _2 * Exp(_3),
-            a_->grad(), adj_, a_->val());
+    // For each row, the Jacobian times vector is given by:
+    // J * dy = p .* (dy - avg*1)
+    // where avg = p'*dy and p is the softmax output (probabilities).
+    Tensor result = adj_;
+    SubtractMean(&result, val_);
+    // beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
+    // to sum gradients from different graph parts.
+    Prod(a_->grad(), adj_, result, false, false, 1.0);
   }
 };
 
@@ -129,11 +133,11 @@ struct LogNodeOp : public UnaryNodeOp {
   template <typename ...Args>
   LogNodeOp(Args ...args)
   : UnaryNodeOp(args...) {}
-  
+
   void forward() {
     Element(_1 = Log(_2), val_, a_->val());
   }
-  
+
   void backward() {
     Element(_1 += _2 * 1.f / _3,
             a_->grad(), adj_, a_->val());
@@ -145,7 +149,7 @@ struct ExpNodeOp : public UnaryNodeOp {
   ExpNodeOp(ChainPtr a, Args ...args)
   : UnaryNodeOp(a, keywords::shape=newShape(a),
                 args...) { }
-  
+
   Shape newShape(ChainPtr a) {
     Shape shape = a->shape();
     return shape;
@@ -154,7 +158,7 @@
   void forward() {
     Element(_1 = Exp(_2), val_, a_->val());
   }
-  
+
   void backward() {
     Element(_1 += _2 * Exp(_3),
             a_->grad(), adj_, a_->val());
@@ -165,11 +169,11 @@ struct NegNodeOp : public UnaryNodeOp {
   template <typename ...Args>
   NegNodeOp(Args ...args)
   : UnaryNodeOp(args...) { }
-  
+
   void forward() {
     Element(_1 = -_2, val_, a_->val());
   }
-  
+
   void backward() {
     Element(_1 += -_2, a_->grad(), adj_);
   }
@@ -194,7 +198,7 @@ struct DotNodeOp : public BinaryNodeOp {
   : BinaryNodeOp(a, b,
                  keywords::shape=newShape(a,b),
                  args...) { }
-  
+
   Shape newShape(ChainPtr a, ChainPtr b) {
     Shape shape1 = a->shape();
     Shape shape2 = b->shape();
@@ -203,12 +207,12 @@ struct DotNodeOp : public BinaryNodeOp {
     shape1[1] = shape2[1];
     return shape1;
   }
-  
+
   void forward() {
     // C = A*B
     Prod(val_, a_->val(), b_->val(), false, false);
   }
-  
+
   void backward() {
     // D is the adjoint, the matrix of derivatives
     // df/dA += D*B.T
diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu
index 2d1d541d..2aa96331 100644
--- a/src/tensor_operators.cu
+++ b/src/tensor_operators.cu
@@ -2,7 +2,57 @@
 
 namespace marian {
 
-// TODO: implement this.
+__global__ void gSubtractMean(float* out, float* weights,
+                              size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if(j < rows) {
+      extern __shared__ float _share[];
+      float* _sum = _share + blockDim.x;
+      float* sp = out + j * cols;
+      float* w = weights + j * cols;
+      _sum[threadIdx.x] = 0.0;
+      for(int tid = 0; tid < cols; tid += blockDim.x) {
+        int id = tid + threadIdx.x;
+        if(id < cols) {
+          _sum[threadIdx.x] += w[id] * sp[id];
+        }
+      }
+      __syncthreads();
+      int len = blockDim.x;
+      while(len != 1) {
+        __syncthreads();
+        int skip = (len + 1) >> 1;
+        if(threadIdx.x < (len >> 1))
+          _sum[threadIdx.x] += _sum[threadIdx.x + skip];
+        len = (len + 1) >> 1;
+      }
+      __syncthreads();
+      for(int tid = 0; tid < cols; tid += blockDim.x){
+        int id = tid + threadIdx.x;
+        if(id < cols)
+          sp[id] -= _sum[0];
+      }
+    }
+  }
+}
+
+void SubtractMean(Tensor* Out, Tensor &Weights) {
+  // Out and Weights are both m-by-k matrices, passed as input.
+  // A weighted average of each row of Out (according to the weights
+  // specified in Weights) is computed and subtracted from Out.
+  // Out is both input and output.
+  size_t m = Out->shape()[0];
+  size_t k = Out->shape()[1];
+
+  int blocks = std::min(MAX_BLOCKS, (int) m);
+  int threads = std::min(MAX_THREADS, (int) k);
+  int shared = sizeof(float) * threads * 2;
+  gSubtractMean<<<blocks, threads, shared>>>(Out->data(), Weights.data(),
+                                             m, k);
+  cudaStreamSynchronize(0);
+}
+
 __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
   for(int bid = 0; bid < rows; bid += gridDim.x) {
     int j = bid + blockIdx.x;
@@ -37,7 +87,6 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
   }
 }
 
-// TODO: implement this.
 void Softmax(Tensor* Out) {
   size_t m = Out->shape()[0];
   size_t k = Out->shape()[1];
diff --git a/src/tensor_operators.h b/src/tensor_operators.h
index a0c30104..03d754e3 100644
--- a/src/tensor_operators.h
+++ b/src/tensor_operators.h
@@ -142,6 +142,11 @@ void Element(Functor functor,
   cudaStreamSynchronize(0);
 }
 
+__global__ void gSubtractMean(float* out, float* weights,
+                              size_t rows, size_t cols);
+
+void SubtractMean(Tensor* Out, Tensor &Weights);
+
 __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
 
 void Softmax(Tensor* Out);
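
A note on the new SoftmaxNodeOp::backward above: the comment gives the row-wise Jacobian-vector product J * dy = p .* (dy - avg) with avg = p' * dy, and SubtractMean computes the (dy - avg) part on the GPU using the softmax probabilities as the weights. Below is a minimal CPU sketch of that formula, assuming row-major m-by-k float arrays; the function and variable names are hypothetical and not part of the patch.

#include <cstddef>

// Reference softmax backward: p holds the softmax probabilities, dy the
// upstream gradient; accumulates J * dy into grad, row by row.
// Per row: avg = sum_j p[j] * dy[j];  grad[j] += p[j] * (dy[j] - avg).
void softmax_grad_ref(float* grad, const float* p, const float* dy,
                      std::size_t rows, std::size_t cols) {
  for (std::size_t i = 0; i < rows; ++i) {
    const float* pi = p + i * cols;
    const float* di = dy + i * cols;
    float* gi = grad + i * cols;
    float avg = 0.0f;                        // avg = p' * dy for this row
    for (std::size_t j = 0; j < cols; ++j)
      avg += pi[j] * di[j];
    for (std::size_t j = 0; j < cols; ++j)
      gi[j] += pi[j] * (di[j] - avg);        // p .* (dy - avg), accumulated
  }
}

The first inner loop corresponds to the per-row weighted sum that gSubtractMean builds in shared memory before subtracting it, and the += accumulation mirrors the beta = 1.0 behaviour noted in the gemm comment.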
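The DotNodeOp::backward hunk ends in trailing context right after the "df/dA += D*B.T" comment, so the implementation is not shown here. For reference, the standard gradients of C = A*B are dA += D*B^T and dB += A^T*D, where D is the adjoint of C; a naive sketch of the first one follows, continuing the same hypothetical row-major conventions as the sketch above and independent of the Prod/gemm wrapper used in the patch.

// Accumulates df/dA += D * B^T for C = A * B, with
// A: m-by-k, B: k-by-n, D (adjoint of C): m-by-n, all row-major.
void dot_grad_a_ref(float* dA, const float* D, const float* B,
                    std::size_t m, std::size_t k, std::size_t n) {
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < k; ++j) {
      float sum = 0.0f;
      for (std::size_t l = 0; l < n; ++l)
        sum += D[i * n + l] * B[j * n + l];  // row i of D dot row j of B
      dA[i * k + j] += sum;                  // accumulate, like beta = 1.0
    }
}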