Fixed backward for fast softmax.

Andre Martins 2016-09-15 13:54:46 +01:00
parent 499faceb8e
commit 6d3f67e955
3 changed files with 4 additions and 19 deletions


@@ -153,20 +153,6 @@ inline Expr sum(Expr a, Args ...args) {
 template <typename ...Args>
 inline Expr softmax(Expr a, Args ...args) {
   Expr e = exp(a);
-#if 0
-  ChainPtr n = a.node();
-  auto print_shape = [n]() -> Shape {
-    std::cerr << "Shape: ";
-    for (auto val : n->val().shape()) {
-      std::cerr << val << " ";
-    }
-    std::cerr << std::endl;
-    return {1,1};
-  };
-  using namespace keywords;
-  Expr one = ones(shape={1, 1}, lazy_shape=print_shape);
-#endif
   return e / sum(e, args...);
 }
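
For reference, the `softmax` above is composed from existing graph ops: exponentiate, then divide by the sum over the given axis, i.e. softmax(a)_i = exp(a_i) / sum_j exp(a_j). A minimal standalone C++ sketch of the same computation on a plain vector (illustrative names, not part of this repo; the max-subtraction is a common numerical stabilization that the expression above does not perform):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Illustrative only: softmax over a single row, mirroring e / sum(e).
    std::vector<float> softmax_row(const std::vector<float>& a) {
      // Subtracting the row max is a standard stabilization trick;
      // the expression graph above computes exp(a) / sum(exp(a)) directly.
      float m = *std::max_element(a.begin(), a.end());
      std::vector<float> e(a.size());
      float sum = 0.f;
      for (size_t i = 0; i < a.size(); ++i) {
        e[i] = std::exp(a[i] - m);
        sum += e[i];
      }
      for (auto& v : e)
        v /= sum;
      return e;
    }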


@@ -162,11 +162,10 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
     // For each row, the Jacobian times vector is given by:
     // J * dy = p .* (dy - avg*1)
     // where avg = p'*dy and p is the softmax output (probabilities).
-    Tensor result = adj_;
+    Tensor result(adj_.shape());
+    thrust::copy(adj_.begin(), adj_.end(), result.begin());
     SubtractMean(&result, val_);
     // beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C,
     // to sum gradients from different graph parts.
     Prod(a_->grad(), adj_, result, false, false, 1.0);
     Element(_1 += _2 * _3, a_->grad(), val_, result);
   }
 };
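
The fix here is the deep copy: `Tensor result = adj_;` apparently shares the underlying device buffer, so the in-place `SubtractMean(&result, val_)` was clobbering the adjoint `adj_`, which is still needed afterwards by `Prod`. Allocating fresh storage and copying with `thrust::copy` leaves `adj_` intact. A minimal CPU sketch of the Jacobian-vector product described in the comment (illustrative names, plain C++; the real code runs on the GPU):

    #include <vector>

    // Illustrative only: softmax backward for one row.
    // p  = softmax output (probabilities), dy = incoming adjoint.
    // Computes J * dy = p .* (dy - avg), where avg = p' * dy.
    std::vector<float> softmax_backward_row(const std::vector<float>& p,
                                            const std::vector<float>& dy) {
      float avg = 0.f;
      for (size_t i = 0; i < p.size(); ++i)
        avg += p[i] * dy[i];
      // Work on a copy of dy, as the fixed code does via thrust::copy,
      // so the caller's adjoint buffer is never mutated in place.
      std::vector<float> result(dy);
      for (size_t i = 0; i < result.size(); ++i)
        result[i] = p[i] * (result[i] - avg);
      return result;
    }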


@@ -19,7 +19,7 @@ int main(int argc, char** argv) {
   //Expr b = param(shape={1, 2}, name="b0");
   std::cerr << "Building model...";
-  auto predict = softmax(dot(x, w),
+  auto predict = softmax_fast(dot(x, w),
       axis=1, name="pred");
   auto graph = -mean(sum(y * log(predict), axis=1),
       axis=0, name="cost");