This commit is contained in:
Hieu Hoang 2016-09-13 18:33:44 +02:00
parent 8f25a1d4bf
commit 6935334b64

View File

@ -12,6 +12,7 @@ int main(int argc, char** argv) {
using namespace marian;
using namespace keywords;
/*
Expr x = input(shape={whatevs, 784}, name="X");
Expr y = input(shape={whatevs, 10}, name="Y");
@ -53,7 +54,29 @@ int main(int argc, char** argv) {
graph.backward();
//std::cerr << graph["pred"].val()[0] << std::endl;
*/
Expr x = input(shape={whatevs, 2}, name="X");
Expr y = input(shape={whatevs, 2}, name="Y");
Expr w = param(shape={2, 1}, name="W0");
Expr b = param(shape={1, 1}, name="b0");
Expr n5 = dot(x, w);
Expr n6 = n5 + b;
Expr lr = softmax(n6, axis=1, name="pred");
cerr << "lr=" << lr.Debug() << endl;
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
Tensor tx({4, 2}, 1);
Tensor ty({4, 1}, 1);
cerr << "tx=" << tx.Debug() << endl;
cerr << "ty=" << ty.Debug() << endl;
tx.Load("../examples/xor/train.txt");
ty.Load("../examples/xor/label.txt");
#if 0
hook0(graph);
graph.autodiff();