Corrected bug in softmax; implemented cross-entropy node along with a few CUDA functions.

Andre Martins 2016-09-18 15:43:28 +01:00
parent 20c80e4bbc
commit c9dc0cf934
6 changed files with 134 additions and 9 deletions
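For reference, the new CrossEntropyNodeOp computes, for each row, -sum_j b[j] * log(softmax(a)[j]), where a holds unnormalized scores and b the gold distribution. A minimal CPU sketch of that formula, for checking only (the helper name and std::vector types are illustrative; the commit's implementation runs on the GPU via thrust and cuBLAS):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference-only: per-row cross entropy of a "safe" softmax
// (row maximum subtracted before exponentiating).
float cross_entropy_row(const std::vector<float>& a,    // unnormalized scores
                        const std::vector<float>& b) {  // gold distribution, e.g. one-hot
  float maxA = *std::max_element(a.begin(), a.end());
  std::vector<float> p(a.size());
  float sum = 0.f;
  for (std::size_t j = 0; j < a.size(); ++j) {
    p[j] = std::exp(a[j] - maxA);
    sum += p[j];
  }
  float cost = 0.f;
  for (std::size_t j = 0; j < a.size(); ++j)
    cost -= b[j] * std::log(p[j] / sum);
  return cost;
}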

View File

@@ -125,4 +125,8 @@ Expr dot(Expr a, Expr b) {
return Expr(a.graph(), new DotNodeOp(a, b));
}
Expr cross_entropy(Expr a, Expr b) {
return Expr(a.graph(), new CrossEntropyNodeOp(a, b));
}
}
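As the training example further down in this commit shows, the fused node replaces the explicit softmax/log/sum chain when building the cost; both lines below are taken from that example:

auto cost = mean(cross_entropy(scores, y), axis=0);                  // new fused node
// auto cost = mean(-sum(y * log(softmax(scores)), axis=1), axis=0); // equivalent explicit chain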

View File

@@ -112,4 +112,6 @@ inline Expr mean(Expr a, Args ...args) {
}
}
Expr cross_entropy(Expr a, Expr b);
}

View File

@@ -42,13 +42,12 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
init=normal()));
}
- auto probs = named(
-   softmax(dot(layers.back(), weights.back()) + biases.back()),
-   "probs"
- );
- auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
+ auto scores = named(dot(layers.back(), weights.back()) + biases.back(),
+                     "scores");
+ auto cost = mean(cross_entropy(scores, y), axis=0);
+ //auto cost = mean(-sum(y * log(softmax(scores)), axis=1), axis=0);
auto costreg = named(
cost, "cost"
);
@@ -142,7 +141,7 @@ int main(int argc, char** argv) {
g.forward(BATCH_SIZE);
std::vector<float> bResults;
bResults << g["probs"].val();
bResults << g["scores"].val();
results.insert(results.end(), bResults.begin(), bResults.end());
}
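Note that bResults now holds raw scores rather than softmax probabilities. Since softmax is strictly increasing, a per-row argmax over the scores gives the same predictions as one over the old probabilities, so no extra normalization is needed for accuracy-style checks; a small sketch with a hypothetical helper, assuming a flat row-major buffer:

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper: predicted class per example from a flat, row-major score buffer.
std::vector<std::size_t> argmax_rows(const std::vector<float>& scores, std::size_t numClasses) {
  std::vector<std::size_t> pred;
  for (std::size_t off = 0; off + numClasses <= scores.size(); off += numClasses) {
    auto first = scores.begin() + off;
    pred.push_back(static_cast<std::size_t>(std::max_element(first, first + numClasses) - first));
  }
  return pred;
}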

View File

@@ -151,7 +151,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
void forward() {
// B = softmax(A).
- val_ = a_->val();
+ thrust::copy(a_->val().begin(), a_->val().end(), val_.begin());
// Safe version of softmax.
Softmax(&val_);
}
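The one-line change above is the softmax bug fix from the commit message: the old assignment presumably made val_ share the input tensor's storage, so the in-place Softmax(&val_) would also have overwritten a_->val(); the explicit thrust::copy writes the values into val_'s own buffer first. A toy sketch of that failure mode, under the shared-storage assumption (ToyTensor is illustrative, not the project's Tensor):

#include <algorithm>
#include <memory>
#include <vector>

struct ToyTensor {                          // assignment copies the pointer,
  std::shared_ptr<std::vector<float>> data; // so two tensors can alias one buffer
};

void forward_buggy(ToyTensor a, ToyTensor val) {
  val = a;  // val now aliases a's storage;
            // an in-place softmax on val would also rewrite a's values
}

void forward_fixed(ToyTensor a, ToyTensor val) {
  std::copy(a.data->begin(), a.data->end(), val.data->begin());  // element copy into val's own buffer;
  // an in-place softmax on val leaves a untouched
}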
@@ -441,4 +441,71 @@ struct DivNodeOp : public BinaryNodeOp {
};
// Cross-entropy node. It computes -b*log(softmax(a)), summing rowwise.
struct CrossEntropyNodeOp : public BinaryNodeOp {
template <typename ...Args>
CrossEntropyNodeOp(ChainPtr a, ChainPtr b, Args ...args)
: BinaryNodeOp(a, b,
keywords::shape=newShape(a, b),
args...) { }
Shape newShape(ChainPtr a, ChainPtr b) {
Shape shape1 = a->shape();
Shape shape2 = b->shape();
UTIL_THROW_IF2(shape1[0] != shape2[0] || shape1[1] != shape2[1],
"cross entropy requires dimensions to match");
shape1[1] = 1;
return shape1;
}
// We're caching the softmax probabilities here because we'll need them for
// the backward computation.
void forward() {
// C = -dot(B, log(softmax(A))).
if (probs_) {
probs_.set(0.0);
} else {
probs_.allocate(a_->val().shape(), 0.0);
}
thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
Softmax(&probs_); // Safe version of softmax.
Tensor result(a_->val().shape());
Element(_1 = -_2 * Log(_3), result, b_->val(), probs_);
SumRowwise(result, val_);
}
// @TODO: In most cases it's wasteful to compute the derivative with respect
// to the second input which is typically an input node in the computation
// graph. In general the backward functions can skip the computation of
// gradients wrt input nodes.
void backward() {
// For each row, the first input derivative is given by adj * (p - y),
// where y is the gold label distribution (e.g. one hot vector) and
// p is the softmax output (probabilities).
// The second input derivative is -adj*log(p).
Tensor result(probs_.shape());
// Compute first input derivative.
Element(_1 = _2 - _3, result, probs_, b_->val());
ScaleRowwise(result, adj_);
Element(_1 += _2, a_->grad(), result);
// Compute second input derivative.
Element(_1 = -Log(_2), result, probs_); // @TODO: use a cached log here.
ScaleRowwise(result, adj_);
Element(_1 += _2, b_->grad(), result);
}
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"cross_entropy\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
protected:
Tensor probs_;
};
}
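The derivative formulas quoted in backward() follow from differentiating the forward computation; with p = softmax(a) and assuming each row of b sums to one (as for a one-hot gold distribution):

\frac{\partial}{\partial a_k}\Bigl(-\sum_j b_j \log p_j\Bigr)
  = \sum_j b_j\,(p_k - \delta_{jk})
  = p_k \sum_j b_j - b_k
  = p_k - b_k,
\qquad
\frac{\partial}{\partial b_k}\Bigl(-\sum_j b_j \log p_j\Bigr) = -\log p_k,

which, scaled rowwise by the incoming adjoint adj_, is exactly what the two Element/ScaleRowwise sequences above compute.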

View File

@@ -227,4 +227,48 @@ Tensor Prod(Tensor C, const Tensor A, const Tensor B,
return temp;
}
Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result) {
size_t rows = A.shape()[0];
size_t cols = A.shape()[1];
thrust::device_vector<float> d_ones(cols, 1.f);
Float alpha = 1.f;
Float beta = 0.f;
cublasSgemv(handle, CUBLAS_OP_T, cols, rows, &alpha,
A.data(), cols,
thrust::raw_pointer_cast(d_ones.data()), 1, &beta,
result.data(), 1);
return result;
}
Tensor SumRowwise(const Tensor A, Tensor result) {
Tensor temp = SumRowwise(cublasHandle, A, result);
return temp;
}
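The gemv call computes the row sums as a matrix-vector product with a vector of ones: cuBLAS assumes column-major storage, so the row-major rows x cols tensor with leading dimension cols is what cuBLAS sees as its transpose, and CUBLAS_OP_T recovers the original orientation. A CPU reference of the intended result, assuming the same row-major layout the gScaleRowwise kernel below uses (names are illustrative):

#include <cstddef>
#include <vector>

// result[j] = sum over the columns of row j (what SumRowwise computes via cublasSgemv).
void sum_rowwise_ref(const std::vector<float>& A, std::size_t rows, std::size_t cols,
                     std::vector<float>& result) {
  result.assign(rows, 0.f);
  for (std::size_t j = 0; j < rows; ++j)
    for (std::size_t i = 0; i < cols; ++i)
      result[j] += A[j * cols + i];
}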
// @TODO: replace this by something else when broadcast elementwise operations
// are ready.
__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
Float* rowOut = out + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols) rowOut[i] *= scalingFactors[j];
}
}
}
}
void ScaleRowwise(Tensor Out, const Tensor ScalingFactors) {
Float* d_out = Out.data();
const Float* d_in = ScalingFactors.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gScaleRowwise<<<blocks, threads>>>(d_out, d_in,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
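For completeness, a CPU reference of what the kernel does (each CUDA block strides over rows, each thread over the columns of a row); the helper name is illustrative:

#include <cstddef>
#include <vector>

// out[j][i] *= scalingFactors[j], with out stored row-major (what gScaleRowwise does on the GPU).
void scale_rowwise_ref(std::vector<float>& out, std::size_t rows, std::size_t cols,
                       const std::vector<float>& scalingFactors) {
  for (std::size_t j = 0; j < rows; ++j)
    for (std::size_t i = 0; i < cols; ++i)
      out[j * cols + i] *= scalingFactors[j];
}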
}

View File

@@ -165,4 +165,13 @@ Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
Tensor Prod(Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, Float beta = 0);
Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result);
Tensor SumRowwise(const Tensor A, Tensor result);
__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
size_t rows, size_t cols);
void ScaleRowwise(Tensor Out, const Tensor ScalingFactors);
}