Corrected bug in softmax; implemented cross-entropy node along with a few CUDA functions.

Andre Martins 2016-09-18 15:43:28 +01:00
parent 20c80e4bbc
commit c9dc0cf934
6 changed files with 134 additions and 9 deletions

View File

@@ -125,4 +125,8 @@ Expr dot(Expr a, Expr b) {
   return Expr(a.graph(), new DotNodeOp(a, b));
 }
 
+Expr cross_entropy(Expr a, Expr b) {
+  return Expr(a.graph(), new CrossEntropyNodeOp(a, b));
+}
+
 }

View File

@@ -112,4 +112,6 @@ inline Expr mean(Expr a, Args ...args) {
 }
 }
 
+Expr cross_entropy(Expr a, Expr b);
+
 }

View File

@@ -42,13 +42,12 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
       init=normal()));
   }
 
-  auto probs = named(
-    softmax(dot(layers.back(), weights.back()) + biases.back()),
-    "probs"
-  );
-
-  auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
+  auto scores = named(dot(layers.back(), weights.back()) + biases.back(),
+                      "scores");
+
+  auto cost = mean(cross_entropy(scores, y), axis=0);
+  //auto cost = mean(-sum(y * log(softmax(scores)), axis=1), axis=0);
 
   auto costreg = named(
     cost, "cost"
   );
@@ -142,7 +141,7 @@ int main(int argc, char** argv) {
     g.forward(BATCH_SIZE);
 
     std::vector<float> bResults;
-    bResults << g["probs"].val();
+    bResults << g["scores"].val();
     results.insert(results.end(), bResults.begin(), bResults.end());
   }
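The new cost expression fuses the old probs/log/sum pipeline into a single cross_entropy node (the commented-out line records the equivalent unfused formula); this avoids materializing log(probs) and lets the backward pass reuse the cached softmax probabilities. As a sanity check, here is a minimal CPU-side sketch, not part of the commit and using made-up toy scores and one-hot labels, of the quantity the new cost computes: a safe per-row softmax, then -sum_j y_j * log(p_j), then the mean over the batch.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Cross-entropy of one row of scores s against a label distribution y:
// -sum_j y_j * log(softmax(s)_j), with the row max subtracted for stability.
float cross_entropy_row(const std::vector<float>& s, const std::vector<float>& y) {
  float maxs = *std::max_element(s.begin(), s.end());
  float Z = 0.f;
  for (float v : s) Z += std::exp(v - maxs);
  float ce = 0.f;
  for (size_t j = 0; j < s.size(); ++j)
    ce -= y[j] * ((s[j] - maxs) - std::log(Z));  // y_j * log(softmax(s)_j)
  return ce;
}

int main() {
  // Toy batch of two rows (made-up numbers), one-hot labels.
  std::vector<std::vector<float>> scores = {{1.0f, 2.0f, 0.5f},
                                            {0.2f, 0.1f, 3.0f}};
  std::vector<std::vector<float>> labels = {{0, 1, 0},
                                            {0, 0, 1}};
  float cost = 0.f;
  for (size_t i = 0; i < scores.size(); ++i)
    cost += cross_entropy_row(scores[i], labels[i]);
  cost /= scores.size();                  // mean over the batch, i.e. axis=0
  std::printf("cost = %f\n", cost);       // ~0.287 for these toy values
  return 0;
}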

View File

@@ -151,7 +151,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
     void forward() {
       // B = softmax(A).
-      val_ = a_->val();
+      thrust::copy(a_->val().begin(), a_->val().end(), val_.begin());
       // Safe version of softmax.
       Softmax(&val_);
     }
@@ -441,4 +441,71 @@ struct DivNodeOp : public BinaryNodeOp {
 };
 
+// Cross-entropy node. It computes -b*log(softmax(a)), summing rowwise.
+struct CrossEntropyNodeOp : public BinaryNodeOp {
+  template <typename ...Args>
+  CrossEntropyNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+    : BinaryNodeOp(a, b,
+                   keywords::shape=newShape(a, b),
+                   args...) { }
+
+  Shape newShape(ChainPtr a, ChainPtr b) {
+    Shape shape1 = a->shape();
+    Shape shape2 = b->shape();
+    UTIL_THROW_IF2(shape1[0] != shape2[0] || shape1[1] != shape2[1],
+                   "cross entropy requires dimensions to match");
+    shape1[1] = 1;
+    return shape1;
+  }
+
+  // We're caching the softmax probabilities here because we'll need them for
+  // the backward computation.
+  void forward() {
+    // C = -dot(B, log(softmax(A))).
+    if (probs_) {
+      probs_.set(0.0);
+    } else {
+      probs_.allocate(a_->val().shape(), 0.0);
+    }
+    thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
+    Softmax(&probs_); // Safe version of softmax.
+    Tensor result(a_->val().shape());
+    Element(_1 = -_2 * Log(_3), result, b_->val(), probs_);
+    SumRowwise(result, val_);
+  }
+
+  // @TODO: In most cases it's wasteful to compute the derivative with respect
+  // to the second input, which is typically an input node in the computation
+  // graph. In general the backward functions can skip the computation of
+  // gradients wrt input nodes.
+  void backward() {
+    // For each row, the first input derivative is given by adj * (p - y),
+    // where y is the gold label distribution (e.g. a one-hot vector) and
+    // p is the softmax output (probabilities).
+    // The second input derivative is -adj * log(p).
+    Tensor result(probs_.shape());
+
+    // Compute first input derivative.
+    Element(_1 = _2 - _3, result, probs_, b_->val());
+    ScaleRowwise(result, adj_);
+    Element(_1 += _2, a_->grad(), result);
+
+    // Compute second input derivative.
+    Element(_1 = -Log(_2), result, probs_); // @TODO: use a cached log here.
+    ScaleRowwise(result, adj_);
+    Element(_1 += _2, b_->grad(), result);
+  }
+
+  virtual std::string graphviz() {
+    std::stringstream ss;
+    ss << "\"" << this << "\" [shape=\"box\", label=\"cross_entropy\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    return ss.str();
+  };
+
+ protected:
+  Tensor probs_;
+};
+
 }
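For the record, here is a short derivation, not part of the commit, of the gradient formulas quoted in backward() above. Write p = softmax(a) for one row and y for the corresponding row of b; the collapse to p - y assumes each row of y sums to one (e.g. one-hot gold labels). The node then scales each row by the incoming adjoint adj, which is what ScaleRowwise does.

\[
C(a, y) = -\sum_j y_j \log p_j, \qquad
p_k = \frac{e^{a_k}}{\sum_j e^{a_j}},
\]
\[
\frac{\partial C}{\partial a_k}
  = -\sum_j y_j \,\frac{\partial \log p_j}{\partial a_k}
  = -\sum_j y_j \,(\delta_{jk} - p_k)
  = p_k \sum_j y_j - y_k
  = p_k - y_k \quad \text{(when $\textstyle\sum_j y_j = 1$)},
\qquad
\frac{\partial C}{\partial y_k} = -\log p_k .
\]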

View File

@@ -227,4 +227,48 @@ Tensor Prod(Tensor C, const Tensor A, const Tensor B,
   return temp;
 }
 
+Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result) {
+  size_t rows = A.shape()[0];
+  size_t cols = A.shape()[1];
+  thrust::device_vector<float> d_ones(cols, 1.f);
+  Float alpha = 1.f;
+  Float beta = 0.f;
+  cublasSgemv(handle, CUBLAS_OP_T, cols, rows, &alpha,
+              A.data(), cols,
+              thrust::raw_pointer_cast(d_ones.data()), 1, &beta,
+              result.data(), 1);
+  return result;
+}
+
+Tensor SumRowwise(const Tensor A, Tensor result) {
+  Tensor temp = SumRowwise(cublasHandle, A, result);
+  return temp;
+}
+
+// @TODO: replace this by something else when broadcast elementwise operations
+// are ready.
+__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
+                              size_t rows, size_t cols) {
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if(j < rows) {
+      Float* rowOut = out + j * cols;
+      for(int tid = 0; tid < cols; tid += blockDim.x) {
+        int i = tid + threadIdx.x;
+        if(i < cols) rowOut[i] *= scalingFactors[j];
+      }
+    }
+  }
+}
+
+void ScaleRowwise(Tensor Out, const Tensor ScalingFactors) {
+  Float* d_out = Out.data();
+  const Float* d_in = ScalingFactors.data();
+  int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
+  int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
+  gScaleRowwise<<<blocks, threads>>>(d_out, d_in,
+                                     Out.shape()[0], Out.shape()[1]);
+  cudaStreamSynchronize(0);
+}
+
 }
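A minimal standalone sketch, not part of the commit and assuming cuBLAS v2 and Thrust are available, of the row-sum-as-GEMV trick used by SumRowwise above: a row-major rows x cols buffer is what column-major cuBLAS sees as a cols x rows matrix, so CUBLAS_OP_T with m = cols, n = rows and lda = cols multiplies the original matrix by a vector of ones and yields the per-row sums.

// Hypothetical build line: nvcc rowsum_gemv.cu -lcublas
#include <cstdio>
#include <vector>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <thrust/copy.h>
#include <thrust/device_vector.h>

int main() {
  const int rows = 2, cols = 3;
  std::vector<float> hA = {1, 2, 3,
                           4, 5, 6};        // row-major, toy values

  thrust::device_vector<float> dA(hA.begin(), hA.end());
  thrust::device_vector<float> dOnes(cols, 1.f);
  thrust::device_vector<float> dSums(rows, 0.f);

  cublasHandle_t handle;
  cublasCreate(&handle);

  float alpha = 1.f, beta = 0.f;
  // cuBLAS sees dA as a cols x rows column-major matrix; transposing it
  // recovers the original rows x cols matrix, and multiplying by a vector
  // of ones gives the per-row sums.
  cublasSgemv(handle, CUBLAS_OP_T, cols, rows, &alpha,
              thrust::raw_pointer_cast(dA.data()), cols,
              thrust::raw_pointer_cast(dOnes.data()), 1, &beta,
              thrust::raw_pointer_cast(dSums.data()), 1);

  std::vector<float> hSums(rows);
  thrust::copy(dSums.begin(), dSums.end(), hSums.begin());
  std::printf("row sums: %.1f %.1f\n", hSums[0], hSums[1]);  // expect 6.0 15.0

  cublasDestroy(handle);
  return 0;
}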

View File

@@ -165,4 +165,13 @@ Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
 Tensor Prod(Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta = 0);
 
+Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result);
+
+Tensor SumRowwise(const Tensor A, Tensor result);
+
+__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
+                              size_t rows, size_t cols);
+
+void ScaleRowwise(Tensor Out, const Tensor ScalingFactors);
+
 }