mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
Fixed bug affecting softmax (was always returning a scalar output).
This commit is contained in:
parent
c05b18071a
commit
6dbe5bb1d0
@ -134,6 +134,7 @@ inline Expr sum(Expr a, Args ...args) {
|
||||
else if(ax == 1) {
|
||||
auto lshape = [n]() -> Shape {
|
||||
int cols = n->val().shape()[1];
|
||||
//std::cerr << "Shape will be " << cols << " by 1." << std::endl;
|
||||
return {cols, 1};
|
||||
};
|
||||
Expr one = ones(shape={n->shape()[1], 1},
|
||||
@ -153,6 +154,20 @@ inline Expr sum(Expr a, Args ...args) {
|
||||
// Builds a softmax expression over `a`: the element-wise exponential of `a`
// divided by the sum of those exponentials.
//
// a     input expression.
// args  forwarded unchanged to sum(); they select how the normalising sum is
//       taken (a caller in this commit passes axis=1 and name="pred").
//
// Returns the expression e / sum(e, args...), where e = exp(a).
template <typename ...Args>
|
||||
inline Expr softmax(Expr a, Args ...args) {
|
||||
Expr e = exp(a);
|
||||
// Disabled debugging aid: when enabled, it sets up a lazy_shape callback
// that prints the dimensions of a's value to stderr and returns a 1x1 shape.
#if 0
|
||||
ChainPtr n = a.node();
|
||||
auto print_shape = [n]() -> Shape {
|
||||
std::cerr << "Shape: ";
|
||||
for (auto val : n->val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
return {1,1};
|
||||
};
|
||||
using namespace keywords;
|
||||
Expr one = ones(shape={1, 1}, lazy_shape=print_shape);
|
||||
#endif
|
||||
|
||||
// Normalise the exponentials by their sum, taken as directed by args.
return e / sum(e, args...);
|
||||
}
|
||||
|
||||
@ -187,4 +202,4 @@ inline Expr mean(Expr a, Args ...args) {
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -118,9 +118,15 @@ struct LogNodeOp : public UnaryNodeOp {
|
||||
|
||||
struct ExpNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
ExpNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...) { }
|
||||
ExpNodeOp(ChainPtr a, Args ...args)
|
||||
: UnaryNodeOp(a, keywords::shape=newShape(a),
|
||||
args...) { }
|
||||
|
||||
Shape newShape(ChainPtr a) {
|
||||
Shape shape = a->shape();
|
||||
return shape;
|
||||
}
|
||||
|
||||
void forward() {
|
||||
Element(_1 = Exp(_2), val_, a_->val());
|
||||
}
|
||||
@ -289,4 +295,4 @@ struct DivNodeOp : public BroadcastingNodeOp {
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
57
src/test.cu
57
src/test.cu
@ -12,7 +12,8 @@ int main(int argc, char** argv) {
|
||||
auto w = param(shape={784, 10}, name="W0");
|
||||
auto b = param(shape={1, 10}, name="b0");
|
||||
|
||||
auto lr = softmax(dot(x, w) + b, axis=1, name="pred");
|
||||
auto scores = dot(x, w) + b;
|
||||
auto lr = softmax(scores, axis=1, name="pred");
|
||||
auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
||||
|
||||
Tensor tx({500, 784}, 1);
|
||||
@ -22,28 +23,44 @@ int main(int argc, char** argv) {
|
||||
y = ty;
|
||||
|
||||
graph.forward(500);
|
||||
std::cerr << "Result: ";
|
||||
for (auto val : scores.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
std::cerr << "Result: ";
|
||||
for (auto val : lr.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
std::cerr << "Log-likelihood: ";
|
||||
for (auto val : graph.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
|
||||
//std::cerr << graph["pred"].val()[0] << std::endl;
|
||||
|
||||
|
||||
//hook0(graph);
|
||||
//graph.autodiff();
|
||||
//std::cerr << graph["cost"].val()[0] << std::endl;
|
||||
#if 0
|
||||
hook0(graph);
|
||||
graph.autodiff();
|
||||
std::cerr << graph["cost"].val()[0] << std::endl;
|
||||
//hook1(graph);
|
||||
//for(auto p : graph.params()) {
|
||||
// auto update = _1 = _1 - alpha * _2;
|
||||
// Element(update, p.val(), p.grad());
|
||||
//}
|
||||
//hook2(graph);
|
||||
//
|
||||
//auto opt = adadelta(cost_function=cost,
|
||||
// eta=0.9, gamma=0.1,
|
||||
// set_batch=set,
|
||||
// before_update=before,
|
||||
// after_update=after,
|
||||
// set_valid=valid,
|
||||
// validation_freq=100,
|
||||
// verbose=1, epochs=3, early_stopping=10);
|
||||
//opt.run();
|
||||
for(auto p : graph.params()) {
|
||||
auto update = _1 = _1 - alpha * _2;
|
||||
Element(update, p.val(), p.grad());
|
||||
}
|
||||
hook2(graph);
|
||||
|
||||
auto opt = adadelta(cost_function=cost,
|
||||
eta=0.9, gamma=0.1,
|
||||
set_batch=set,
|
||||
before_update=before,
|
||||
after_update=after,
|
||||
set_valid=valid,
|
||||
validation_freq=100,
|
||||
verbose=1, epochs=3, early_stopping=10);
|
||||
opt.run();
|
||||
#endif
|
||||
return 0;
|
||||
}
|
Loading…
Reference in New Issue
Block a user