Fixed bug affecting softmax (was always returning a scalar output).

Andre Martins 2016-09-13 16:06:53 +01:00
parent c05b18071a
commit 6dbe5bb1d0
3 changed files with 62 additions and 24 deletions
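Why the exp shape mattered here: softmax is built below as exp(a) / sum(exp(a), axis=1), but the old ExpNodeOp constructor never told UnaryNodeOp what shape its result has, so the node evidently fell back to a scalar default and the whole softmax collapsed to one element. A minimal standalone sketch of that failure mode, using stand-in types rather than Marian's real classes (the {1, 1} default is inferred from the bug report, not taken from the library):

    // Sketch only: Shape, Node, and the exp_* helpers are stand-ins.
    #include <cassert>
    #include <vector>

    using Shape = std::vector<int>;
    struct Node { Shape shape; };

    // Before the fix: no shape is propagated, so the result stays at the
    // assumed default {1, 1} and everything downstream becomes a scalar.
    Node exp_before(const Node&)  { return Node{Shape{1, 1}}; }

    // After the fix: the result inherits its operand's shape.
    Node exp_after(const Node& a) { return Node{a.shape}; }

    int main() {
      Node a{Shape{500, 10}};
      assert(exp_before(a).shape == (Shape{1, 1}));    // the reported bug
      assert(exp_after(a).shape  == (Shape{500, 10})); // shape preserved
      return 0;
    }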

View File

@@ -134,6 +134,7 @@ inline Expr sum(Expr a, Args ...args) {
   else if(ax == 1) {
     auto lshape = [n]() -> Shape {
       int cols = n->val().shape()[1];
+      //std::cerr << "Shape will be " << cols << " by 1." << std::endl;
       return {cols, 1};
     };
     Expr one = ones(shape={n->shape()[1], 1},
@@ -153,6 +154,20 @@ inline Expr sum(Expr a, Args ...args) {
 template <typename ...Args>
 inline Expr softmax(Expr a, Args ...args) {
   Expr e = exp(a);
+#if 0
+  ChainPtr n = a.node();
+  auto print_shape = [n]() -> Shape {
+    std::cerr << "Shape: ";
+    for (auto val : n->val().shape()) {
+      std::cerr << val << " ";
+    }
+    std::cerr << std::endl;
+    return {1,1};
+  };
+  using namespace keywords;
+  Expr one = ones(shape={1, 1}, lazy_shape=print_shape);
+#endif
+
   return e / sum(e, args...);
 }
 
@@ -187,4 +202,4 @@ inline Expr mean(Expr a, Args ...args) {
 }
 }
 }
-}
\ No newline at end of file
+}
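For reference, the computation the softmax above encodes: exponentiate elementwise, then divide each row by its sum (axis=1 reduces across a row's columns via the dot with a ones vector), so the output must keep the input's shape. The same arithmetic on a plain 2-D array, independent of the expression graph:

    // Plain-array softmax over axis 1; no Marian types involved.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<std::vector<double>> a = {{1.0, 2.0, 3.0},
                                            {1.0, 1.0, 1.0}};
      for (auto& row : a) {
        double denom = 0.0;
        for (double v : row) denom += std::exp(v);   // sum(exp(a), axis=1)
        for (double& v : row) v = std::exp(v) / denom;
      }
      // Shape is unchanged and each row sums to 1; the second row comes
      // out as {1/3, 1/3, 1/3}.
      for (const auto& row : a) {
        for (double v : row) std::printf("%.4f ", v);
        std::printf("\n");
      }
      return 0;
    }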

View File

@@ -118,9 +118,15 @@ struct LogNodeOp : public UnaryNodeOp {
 
 struct ExpNodeOp : public UnaryNodeOp {
   template <typename ...Args>
-  ExpNodeOp(Args ...args)
-  : UnaryNodeOp(args...) { }
+  ExpNodeOp(ChainPtr a, Args ...args)
+  : UnaryNodeOp(a, keywords::shape=newShape(a),
+                args...) { }
+
+  Shape newShape(ChainPtr a) {
+    Shape shape = a->shape();
+    return shape;
+  }
 
   void forward() {
     Element(_1 = Exp(_2), val_, a_->val());
   }
@@ -289,4 +295,4 @@ struct DivNodeOp : public BroadcastingNodeOp {
 }
 };
 }
-}
\ No newline at end of file
+}
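The hunk above is the crux of the commit: ExpNodeOp's constructor now computes its output shape from the operand via newShape and hands it to the base class through the shape keyword, instead of letting the base pick a default. A self-contained imitation of that pattern (plain C++; the {1, 1} fallback is an assumption standing in for the keywords machinery, and the *Sketch names are invented):

    #include <cassert>
    #include <utility>
    #include <vector>

    using Shape = std::vector<int>;

    struct Chain { Shape s; const Shape& shape() const { return s; } };

    struct UnaryNodeOpSketch {
      Shape shape;
      // Assumed default: with no explicit shape, the op becomes {1, 1}.
      explicit UnaryNodeOpSketch(Shape sh = Shape{1, 1})
      : shape(std::move(sh)) {}
    };

    struct ExpNodeOpSketch : UnaryNodeOpSketch {
      // Mirrors the fixed constructor: derive the shape from the operand.
      explicit ExpNodeOpSketch(const Chain& a)
      : UnaryNodeOpSketch(newShape(a)) {}
      static Shape newShape(const Chain& a) { return a.shape(); }
    };

    int main() {
      Chain a{Shape{500, 10}};
      assert(ExpNodeOpSketch(a).shape == (Shape{500, 10}));
      return 0;
    }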

View File

@@ -12,7 +12,8 @@ int main(int argc, char** argv) {
   auto w = param(shape={784, 10}, name="W0");
   auto b = param(shape={1, 10}, name="b0");
 
-  auto lr = softmax(dot(x, w) + b, axis=1, name="pred");
+  auto scores = dot(x, w) + b;
+  auto lr = softmax(scores, axis=1, name="pred");
   auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
 
   Tensor tx({500, 784}, 1);
@@ -22,28 +23,44 @@ int main(int argc, char** argv) {
   y = ty;
 
   graph.forward(500);
+  std::cerr << "Result: ";
+  for (auto val : scores.val().shape()) {
+    std::cerr << val << " ";
+  }
+  std::cerr << std::endl;
+  std::cerr << "Result: ";
+  for (auto val : lr.val().shape()) {
+    std::cerr << val << " ";
+  }
+  std::cerr << std::endl;
+  std::cerr << "Log-likelihood: ";
+  for (auto val : graph.val().shape()) {
+    std::cerr << val << " ";
+  }
+  std::cerr << std::endl;
 
-  //std::cerr << graph["pred"].val()[0] << std::endl;
-  //hook0(graph);
-  //graph.autodiff();
-  //std::cerr << graph["cost"].val()[0] << std::endl;
+#if 0
+  hook0(graph);
+  graph.autodiff();
+  std::cerr << graph["cost"].val()[0] << std::endl;
   //hook1(graph);
 
-  //for(auto p : graph.params()) {
-  //  auto update = _1 = _1 - alpha * _2;
-  //  Element(update, p.val(), p.grad());
-  //}
-  //hook2(graph);
-  //
-  //auto opt = adadelta(cost_function=cost,
-  //                    eta=0.9, gamma=0.1,
-  //                    set_batch=set,
-  //                    before_update=before,
-  //                    after_update=after,
-  //                    set_valid=valid,
-  //                    validation_freq=100,
-  //                    verbose=1, epochs=3, early_stopping=10);
-  //opt.run();
+  for(auto p : graph.params()) {
+    auto update = _1 = _1 - alpha * _2;
+    Element(update, p.val(), p.grad());
+  }
+  hook2(graph);
+
+  auto opt = adadelta(cost_function=cost,
+                      eta=0.9, gamma=0.1,
+                      set_batch=set,
+                      before_update=before,
+                      after_update=after,
+                      set_valid=valid,
+                      validation_freq=100,
+                      verbose=1, epochs=3, early_stopping=10);
+  opt.run();
+#endif
 
   return 0;
 }
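Given the declared dimensions, the three new printouts should report scores and pred as {500, 10} and the cost as {1, 1} once the fix is in — my inference from matrix-product rules, not output recorded with the commit. The shape arithmetic, spelled out:

    // Shape arithmetic only; names are illustrative.
    #include <cassert>
    #include <utility>

    using Shape2 = std::pair<int, int>;  // {rows, cols}

    // (m x k) . (k x n) -> (m x n); adding the {1, n} bias broadcasts.
    Shape2 dot_shape(Shape2 a, Shape2 b) { return {a.first, b.second}; }

    int main() {
      Shape2 x{500, 784}, w{784, 10};
      Shape2 scores = dot_shape(x, w);   // {500, 10}
      Shape2 pred = scores;              // softmax keeps its input shape
      Shape2 summed{pred.first, 1};      // sum(..., axis=1) -> {500, 1}
      Shape2 cost{1, summed.second};     // mean(..., axis=0) -> {1, 1}
      assert(scores == (Shape2{500, 10}) && cost == (Shape2{1, 1}));
      return 0;
    }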