diff --git a/src/expression_operators.cu b/src/expression_operators.cu
index 59c1c52d..19cb1fa9 100644
--- a/src/expression_operators.cu
+++ b/src/expression_operators.cu
@@ -125,4 +125,8 @@ Expr dot(Expr a, Expr b) {
   return Expr(a.graph(), new DotNodeOp(a, b));
 }
 
+Expr cross_entropy(Expr a, Expr b) {
+  return Expr(a.graph(), new CrossEntropyNodeOp(a, b));
+}
+
 }
diff --git a/src/expression_operators.h b/src/expression_operators.h
index 6a9b4e53..8da89824 100644
--- a/src/expression_operators.h
+++ b/src/expression_operators.h
@@ -112,4 +112,6 @@ inline Expr mean(Expr a, Args ...args) {
   }
 }
 
+Expr cross_entropy(Expr a, Expr b);
+
 }
diff --git a/src/mnist_benchmark.cu b/src/mnist_benchmark.cu
index 5f0bc705..cbd0ebd6 100644
--- a/src/mnist_benchmark.cu
+++ b/src/mnist_benchmark.cu
@@ -42,13 +42,12 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
                        init=normal()));
   }
-
-  auto probs = named(
-    softmax(dot(layers.back(), weights.back()) + biases.back()),
-    "probs"
-  );
-
-  auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
+
+  auto scores = named(dot(layers.back(), weights.back()) + biases.back(),
+                      "scores");
+
+  auto cost = mean(cross_entropy(scores, y), axis=0);
+  //auto cost = mean(-sum(y * log(softmax(scores)), axis=1), axis=0);
 
   auto costreg = named(
     cost, "cost"
   );
@@ -142,7 +141,7 @@ int main(int argc, char** argv) {
     g.forward(BATCH_SIZE);
 
     std::vector<float> bResults;
-    bResults << g["probs"].val();
+    bResults << g["scores"].val();
 
     results.insert(results.end(), bResults.begin(), bResults.end());
   }
diff --git a/src/node_operators.h b/src/node_operators.h
index 507d967d..a2cefe70 100644
--- a/src/node_operators.h
+++ b/src/node_operators.h
@@ -151,7 +151,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
   void forward() {
     // B = softmax(A).
-    val_ = a_->val();
+    thrust::copy(a_->val().begin(), a_->val().end(), val_.begin());
     // Safe version of softmax.
     Softmax(&val_);
   }
 
@@ -441,4 +441,71 @@ struct DivNodeOp : public BinaryNodeOp {
 };
 
 
+// Cross-entropy node. It computes -b*log(softmax(a)), summing rowwise.
+struct CrossEntropyNodeOp : public BinaryNodeOp {
+  template <typename ...Args>
+  CrossEntropyNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+    : BinaryNodeOp(a, b,
+                   keywords::shape=newShape(a, b),
+                   args...) { }
+
+  Shape newShape(ChainPtr a, ChainPtr b) {
+    Shape shape1 = a->shape();
+    Shape shape2 = b->shape();
+    UTIL_THROW_IF2(shape1[0] != shape2[0] || shape1[1] != shape2[1],
+                   "cross entropy requires dimensions to match");
+    shape1[1] = 1;
+    return shape1;
+  }
+
+  // We're caching the softmax probabilities here because we'll need them for
+  // the backward computation.
+  void forward() {
+    // C = -dot(B, log(softmax(A))).
+    if (probs_) {
+      probs_.set(0.0);
+    } else {
+      probs_.allocate(a_->val().shape(), 0.0);
+    }
+    thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
+    Softmax(&probs_); // Safe version of softmax.
+    Tensor result(a_->val().shape());
+    Element(_1 = -_2 * Log(_3), result, b_->val(), probs_);
+    SumRowwise(result, val_);
+  }
+
+  // @TODO: In most cases it's wasteful to compute the derivative with respect
+  // to the second input which is typically an input node in the computation
+  // graph. In general the backward functions can skip the computation of
+  // gradients wrt input nodes.
+  void backward() {
+    // For each row, the first input derivative is given by adj * (p - y),
+    // where y is the gold label distribution (e.g. one hot vector) and
+    // p is the softmax output (probabilities).
+    // The second input derivative is -adj*log(p).
+    Tensor result(probs_.shape());
+
+    // Compute first input derivative.
+    Element(_1 = _2 - _3, result, probs_, b_->val());
+    ScaleRowwise(result, adj_);
+    Element(_1 += _2, a_->grad(), result);
+
+    // Compute second input derivative.
+    Element(_1 = -Log(_2), result, probs_); // @TODO: use a cached log here.
+    ScaleRowwise(result, adj_);
+    Element(_1 += _2, b_->grad(), result);
+  }
+
+  virtual std::string graphviz() {
+    std::stringstream ss;
+    ss << "\"" << this << "\" [shape=\"box\", label=\"cross_entropy\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    return ss.str();
+  };
+
+ protected:
+  Tensor probs_;
+
+};
+
 }
diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu
index 1d0f7e2f..fe4e7353 100644
--- a/src/tensor_operators.cu
+++ b/src/tensor_operators.cu
@@ -227,4 +227,48 @@ Tensor Prod(Tensor C, const Tensor A, const Tensor B,
   return temp;
 }
 
+Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result) {
+  size_t rows = A.shape()[0];
+  size_t cols = A.shape()[1];
+  thrust::device_vector<Float> d_ones(cols, 1.f);
+  Float alpha = 1.f;
+  Float beta = 0.f;
+  cublasSgemv(handle, CUBLAS_OP_T, cols, rows, &alpha,
+              A.data(), cols,
+              thrust::raw_pointer_cast(d_ones.data()), 1, &beta,
+              result.data(), 1);
+  return result;
+}
+
+Tensor SumRowwise(const Tensor A, Tensor result) {
+  Tensor temp = SumRowwise(cublasHandle, A, result);
+  return temp;
+}
+
+// @TODO: replace this by something else when broadcast elementwise operations
+// are ready.
+__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
+                              size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if(j < rows) {
+      Float* rowOut = out + j * cols;
+      for(int tid = 0; tid < cols; tid += blockDim.x) {
+        int i = tid + threadIdx.x;
+        if(i < cols) rowOut[i] *= scalingFactors[j];
+      }
+    }
+  }
+}
+
+void ScaleRowwise(Tensor Out, const Tensor ScalingFactors) {
+  Float* d_out = Out.data();
+  const Float* d_in = ScalingFactors.data();
+  int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
+  int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
+  gScaleRowwise<<<blocks, threads>>>(d_out, d_in,
+                                     Out.shape()[0], Out.shape()[1]);
+  cudaStreamSynchronize(0);
+}
+
 }
\ No newline at end of file
diff --git a/src/tensor_operators.h b/src/tensor_operators.h
index 039e6f39..52762272 100644
--- a/src/tensor_operators.h
+++ b/src/tensor_operators.h
@@ -165,4 +165,13 @@ Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
 Tensor Prod(Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta = 0);
 
+Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result);
+
+Tensor SumRowwise(const Tensor A, Tensor result);
+
+__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
+                              size_t rows, size_t cols);
+
+void ScaleRowwise(Tensor Out, const Tensor ScalingFactors);
+
 }
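For reference (not part of the patch): the backward() comments above state the standard softmax-with-cross-entropy gradient. The sketch below is a minimal host-side C++ rendering of the same math: forward caches p = softmax(a) and returns the rowwise sum of -y*log(p); backward adds adj*(p - y) to the first input's gradient and adj*(-log(p)) to the second. The struct name CrossEntropyRef and the flat row-major std::vector<float> layout are illustrative assumptions, not part of the codebase.

#include <algorithm>
#include <cmath>
#include <vector>

// Host-side reference for CrossEntropyNodeOp, assuming row-major
// [rows x cols] matrices stored as flat std::vector<float>.
struct CrossEntropyRef {
  size_t rows, cols;
  std::vector<float> probs;  // cached softmax(a), analogous to the node's probs_

  CrossEntropyRef(size_t r, size_t c) : rows(r), cols(c), probs(r * c) {}

  // val[i] = -sum_j y[i][j] * log(softmax(a)[i][j])
  void forward(const std::vector<float>& a, const std::vector<float>& y,
               std::vector<float>& val) {
    for (size_t i = 0; i < rows; ++i) {
      // "Safe" softmax: subtract the row maximum before exponentiating.
      float max = a[i * cols];
      for (size_t j = 1; j < cols; ++j) max = std::max(max, a[i * cols + j]);
      float sum = 0.f;
      for (size_t j = 0; j < cols; ++j) {
        probs[i * cols + j] = std::exp(a[i * cols + j] - max);
        sum += probs[i * cols + j];
      }
      float ce = 0.f;
      for (size_t j = 0; j < cols; ++j) {
        probs[i * cols + j] /= sum;
        ce -= y[i * cols + j] * std::log(probs[i * cols + j]);
      }
      val[i] = ce;  // rowwise sum: one cost value per row
    }
  }

  // gradA[i][j] += adj[i] * (p[i][j] - y[i][j]);  gradB[i][j] += adj[i] * -log(p[i][j])
  void backward(const std::vector<float>& y, const std::vector<float>& adj,
                std::vector<float>& gradA, std::vector<float>& gradB) {
    for (size_t i = 0; i < rows; ++i)
      for (size_t j = 0; j < cols; ++j) {
        gradA[i * cols + j] += adj[i] * (probs[i * cols + j] - y[i * cols + j]);
        gradB[i * cols + j] += adj[i] * -std::log(probs[i * cols + j]);
      }
  }
};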
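Also for reference (again, not part of the patch): SumRowwise obtains row sums as a matrix-vector product with a vector of ones. Because cuBLAS assumes column-major storage, the row-major [rows x cols] buffer with leading dimension cols is interpreted as its [cols x rows] transpose; CUBLAS_OP_T undoes that, so the gemv against the ones vector produces one sum per row. A host-side sketch of the identity the call relies on, with the helper name sumRowwiseRef being illustrative:

#include <cstddef>
#include <vector>

// result[i] = dot(row_i of A, ones) = sum_j A[i][j], for a flat row-major A.
std::vector<float> sumRowwiseRef(const std::vector<float>& A,
                                 size_t rows, size_t cols) {
  std::vector<float> ones(cols, 1.f), result(rows, 0.f);
  for (size_t i = 0; i < rows; ++i)
    for (size_t j = 0; j < cols; ++j)
      result[i] += A[i * cols + j] * ones[j];
  return result;
}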